diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 325e7c260..3618c0608 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -1,22 +1,22 @@ -''' +""" Filename: MetaGPT/examples/agent_creator.py Created Date: Tuesday, September 12th 2023, 3:28:37 pm Author: garylin2099 -''' +""" import re -from metagpt.const import PROJECT_ROOT, WORKSPACE_ROOT from metagpt.actions import Action +from metagpt.const import PROJECT_ROOT, WORKSPACE_ROOT +from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.logs import logger with open(PROJECT_ROOT / "examples/build_customized_agent.py", "r") as f: # use official example script to guide AgentCreator MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read() -class CreateAgent(Action): +class CreateAgent(Action): PROMPT_TEMPLATE = """ ### BACKGROUND You are using an agent framework called metagpt to write agents capable of different actions, @@ -34,7 +34,6 @@ class CreateAgent(Action): """ async def run(self, example: str, instruction: str): - prompt = self.PROMPT_TEMPLATE.format(example=example, instruction=instruction) # logger.info(prompt) @@ -46,13 +45,14 @@ class CreateAgent(Action): @staticmethod def parse_code(rsp): - pattern = r'```python(.*)```' + pattern = r"```python(.*)```" match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" with open(WORKSPACE_ROOT / "agent_created_agent.py", "w") as f: f.write(code_text) return code_text + class AgentCreator(Role): def __init__( self, @@ -76,11 +76,11 @@ class AgentCreator(Role): return msg + if __name__ == "__main__": import asyncio async def main(): - agent_template = MULTI_ACTION_AGENT_CODE_EXAMPLE creator = AgentCreator(agent_template=agent_template) diff --git a/examples/build_customized_agent.py b/examples/build_customized_agent.py index 87d7a9c76..ef274be8b 100644 --- a/examples/build_customized_agent.py +++ b/examples/build_customized_agent.py @@ -1,21 +1,21 @@ -''' +""" Filename: MetaGPT/examples/build_customized_agent.py Created Date: Tuesday, September 19th 2023, 6:52:25 pm Author: garylin2099 -''' +""" +import asyncio import re import subprocess -import asyncio import fire from metagpt.actions import Action +from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.logs import logger + class SimpleWriteCode(Action): - PROMPT_TEMPLATE = """ Write a python function that can {instruction} and provide two runnnable test cases. Return ```python your_code_here ``` with NO other texts, @@ -35,7 +35,6 @@ class SimpleWriteCode(Action): super().__init__(name, context, llm) async def run(self, instruction: str): - prompt = self.PROMPT_TEMPLATE.format(instruction=instruction) rsp = await self._aask(prompt) @@ -46,11 +45,12 @@ class SimpleWriteCode(Action): @staticmethod def parse_code(rsp): - pattern = r'```python(.*)```' + pattern = r"```python(.*)```" match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else rsp return code_text + class SimpleRunCode(Action): def __init__(self, name="SimpleRunCode", context=None, llm=None): super().__init__(name, context, llm) @@ -61,6 +61,7 @@ class SimpleRunCode(Action): logger.info(f"{code_result=}") return code_result + class SimpleCoder(Role): def __init__( self, @@ -75,7 +76,7 @@ class SimpleCoder(Role): logger.info(f"{self._setting}: ready to {self._rc.todo}") todo = self._rc.todo - msg = self._rc.memory.get()[-1] # retrieve the latest memory + msg = self._rc.memory.get()[-1] # retrieve the latest memory instruction = msg.content code_text = await SimpleWriteCode().run(instruction) @@ -83,6 +84,7 @@ class SimpleCoder(Role): return msg + class RunnableCoder(Role): def __init__( self, @@ -128,6 +130,7 @@ class RunnableCoder(Role): await self._act() return Message(content="All job done", role=self.profile) + def main(msg="write a function that calculates the sum of a list"): # role = SimpleCoder() role = RunnableCoder() @@ -135,5 +138,6 @@ def main(msg="write a function that calculates the sum of a list"): result = asyncio.run(role.run(msg)) logger.info(result) -if __name__ == '__main__': + +if __name__ == "__main__": fire.Fire(main) diff --git a/examples/debate.py b/examples/debate.py index 05db28070..54da73cca 100644 --- a/examples/debate.py +++ b/examples/debate.py @@ -1,17 +1,19 @@ -''' +""" Filename: MetaGPT/examples/debate.py Created Date: Tuesday, September 19th 2023, 6:52:25 pm Author: garylin2099 -''' +""" import asyncio import platform + import fire -from metagpt.software_company import SoftwareCompany from metagpt.actions import Action, BossRequirement +from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.logs import logger +from metagpt.software_company import SoftwareCompany + class ShoutOut(Action): """Action: Shout out loudly in a debate (quarrel)""" @@ -31,7 +33,6 @@ class ShoutOut(Action): super().__init__(name, context, llm) async def run(self, context: str, name: str, opponent_name: str): - prompt = self.PROMPT_TEMPLATE.format(context=context, name=name, opponent_name=opponent_name) # logger.info(prompt) @@ -39,6 +40,7 @@ class ShoutOut(Action): return rsp + class Trump(Role): def __init__( self, @@ -55,7 +57,7 @@ class Trump(Role): async def _observe(self) -> int: await super()._observe() # accept messages sent (from opponent) to self, disregard own messages from the last round - self._rc.news = [msg for msg in self._rc.news if msg.send_to == self.name] + self._rc.news = [msg for msg in self._rc.news if msg.send_to == self.name] return len(self._rc.news) async def _act(self) -> Message: @@ -79,6 +81,7 @@ class Trump(Role): return msg + class Biden(Role): def __init__( self, @@ -120,10 +123,12 @@ class Biden(Role): return msg -async def startup(idea: str, investment: float = 3.0, n_round: int = 5, - code_review: bool = False, run_tests: bool = False): + +async def startup( + idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool = False, run_tests: bool = False +): """We reuse the startup paradigm for roles to interact with each other. - Now we run a startup of presidents and watch they quarrel. :) """ + Now we run a startup of presidents and watch they quarrel. :)""" company = SoftwareCompany() company.hire([Biden(), Trump()]) company.invest(investment) @@ -133,7 +138,7 @@ async def startup(idea: str, investment: float = 3.0, n_round: int = 5, def main(idea: str, investment: float = 3.0, n_round: int = 10): """ - :param idea: Debate topic, such as "Topic: The U.S. should commit more in climate change fighting" + :param idea: Debate topic, such as "Topic: The U.S. should commit more in climate change fighting" or "Trump: Climate change is a hoax" :param investment: contribute a certain dollar amount to watch the debate :param n_round: maximum rounds of the debate @@ -144,5 +149,5 @@ def main(idea: str, investment: float = 3.0, n_round: int = 10): asyncio.run(startup(idea, investment, n_round)) -if __name__ == '__main__': +if __name__ == "__main__": fire.Fire(main) diff --git a/examples/invoice_ocr.py b/examples/invoice_ocr.py index 11656ed52..a6e565772 100644 --- a/examples/invoice_ocr.py +++ b/examples/invoice_ocr.py @@ -19,19 +19,15 @@ async def main(): Path("../tests/data/invoices/invoice-1.pdf"), Path("../tests/data/invoices/invoice-2.png"), Path("../tests/data/invoices/invoice-3.jpg"), - Path("../tests/data/invoices/invoice-4.zip") + Path("../tests/data/invoices/invoice-4.zip"), ] # The absolute path of the file absolute_file_paths = [Path.cwd() / path for path in relative_paths] for path in absolute_file_paths: role = InvoiceOCRAssistant() - await role.run(Message( - content="Invoicing date", - instruct_content={"file_path": path} - )) + await role.run(Message(content="Invoicing date", instruct_content={"file_path": path})) -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) - diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 3ba03eea0..677098399 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -14,11 +14,11 @@ from metagpt.logs import logger async def main(): llm = LLM() claude = Claude() - logger.info(await claude.aask('你好,请进行自我介绍')) - logger.info(await llm.aask('hello world')) - logger.info(await llm.aask_batch(['hi', 'write python hello world.'])) + logger.info(await claude.aask("你好,请进行自我介绍")) + logger.info(await llm.aask("hello world")) + logger.info(await llm.aask_batch(["hi", "write python hello world."])) - hello_msg = [{'role': 'user', 'content': 'count from 1 to 10. split by newline.'}] + hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}] logger.info(await llm.acompletion(hello_msg)) logger.info(await llm.acompletion_batch([hello_msg])) logger.info(await llm.acompletion_batch_text([hello_msg])) @@ -27,5 +27,5 @@ async def main(): await llm.acompletion_text(hello_msg, stream=True) -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/research.py b/examples/research.py index 344f8d0e9..5c371cdd2 100644 --- a/examples/research.py +++ b/examples/research.py @@ -12,5 +12,5 @@ async def main(): print(f"save report to {RESEARCH_PATH / f'{topic}.md'}.") -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/search_google.py b/examples/search_google.py index 9e9521b9c..73d04bf87 100644 --- a/examples/search_google.py +++ b/examples/search_google.py @@ -15,5 +15,5 @@ async def main(): await Searcher().run("What are some good sun protection products?") -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/search_kb.py b/examples/search_kb.py index b6f7d87a0..0b5d59385 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -12,7 +12,7 @@ from metagpt.roles import Sales async def search(): - store = FaissStore(DATA_PATH / 'example.json') + store = FaissStore(DATA_PATH / "example.json") role = Sales(profile="Sales", store=store) queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"] @@ -22,5 +22,5 @@ async def search(): logger.info(result) -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(search()) diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 7cc431cd4..334a7821f 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -6,11 +6,12 @@ from metagpt.tools import SearchEngineType async def main(): # Serper API - #await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"]) + # await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"]) # SerpAPI - #await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?") + # await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?") # Google API await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run("What are the most interesting human facts?") -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/use_off_the_shelf_agent.py b/examples/use_off_the_shelf_agent.py index 2e10068bd..4445a6c62 100644 --- a/examples/use_off_the_shelf_agent.py +++ b/examples/use_off_the_shelf_agent.py @@ -1,12 +1,13 @@ -''' +""" Filename: MetaGPT/examples/use_off_the_shelf_agent.py Created Date: Tuesday, September 19th 2023, 6:52:25 pm Author: garylin2099 -''' +""" import asyncio -from metagpt.roles.product_manager import ProductManager from metagpt.logs import logger +from metagpt.roles.product_manager import ProductManager + async def main(): msg = "Write a PRD for a snake game" @@ -14,5 +15,6 @@ async def main(): result = await role.run(msg) logger.info(result.content[:100]) -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 71ece5527..0dba3cdb7 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -16,6 +16,5 @@ async def main(): await role.run(topic) -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) - diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py index ea7f4fb80..25326d43b 100644 --- a/metagpt/actions/action_output.py +++ b/metagpt/actions/action_output.py @@ -23,10 +23,10 @@ class ActionOutput: def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): new_class = create_model(class_name, **mapping) - @validator('*', allow_reuse=True) + @validator("*", allow_reuse=True) def check_name(v, field): if field.name not in mapping.keys(): - raise ValueError(f'Unrecognized block: {field.name}') + raise ValueError(f"Unrecognized block: {field.name}") return v @root_validator(pre=True, allow_reuse=True) @@ -34,10 +34,9 @@ class ActionOutput: required_fields = set(mapping.keys()) missing_fields = required_fields - set(values.keys()) if missing_fields: - raise ValueError(f'Missing fields: {missing_fields}') + raise ValueError(f"Missing fields: {missing_fields}") return values new_class.__validator_check_name = classmethod(check_name) new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - \ No newline at end of file diff --git a/metagpt/actions/add_requirement.py b/metagpt/actions/add_requirement.py index 7dc09d062..16e14b3a4 100644 --- a/metagpt/actions/add_requirement.py +++ b/metagpt/actions/add_requirement.py @@ -10,5 +10,6 @@ from metagpt.actions import Action class BossRequirement(Action): """Boss Requirement without any implementation details""" + async def run(self, *args, **kwargs): raise NotImplementedError diff --git a/metagpt/actions/azure_tts.py b/metagpt/actions/azure_tts.py index c13a4750d..daa3f6892 100644 --- a/metagpt/actions/azure_tts.py +++ b/metagpt/actions/azure_tts.py @@ -18,16 +18,13 @@ class AzureTTS(Action): # Parameters reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles def synthesize_speech(self, lang, voice, role, text, output_file): - subscription_key = self.config.get('AZURE_TTS_SUBSCRIPTION_KEY') - region = self.config.get('AZURE_TTS_REGION') - speech_config = SpeechConfig( - subscription=subscription_key, region=region) + subscription_key = self.config.get("AZURE_TTS_SUBSCRIPTION_KEY") + region = self.config.get("AZURE_TTS_REGION") + speech_config = SpeechConfig(subscription=subscription_key, region=region) speech_config.speech_synthesis_voice_name = voice audio_config = AudioConfig(filename=output_file) - synthesizer = SpeechSynthesizer( - speech_config=speech_config, - audio_config=audio_config) + synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config) # if voice=="zh-CN-YunxiNeural": ssml_string = f""" @@ -45,9 +42,4 @@ class AzureTTS(Action): if __name__ == "__main__": azure_tts = AzureTTS("azure_tts") - azure_tts.synthesize_speech( - "zh-CN", - "zh-CN-YunxiNeural", - "Boy", - "Hello, I am Kaka", - "output.wav") + azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav") diff --git a/metagpt/actions/clone_function.py b/metagpt/actions/clone_function.py index cf7d22f04..1447e8dbf 100644 --- a/metagpt/actions/clone_function.py +++ b/metagpt/actions/clone_function.py @@ -1,5 +1,5 @@ -from pathlib import Path import traceback +from pathlib import Path from metagpt.actions.write_code import WriteCode from metagpt.logs import logger @@ -42,7 +42,7 @@ class CloneFunction(WriteCode): prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func) logger.info(f"query for CloneFunction: \n {prompt}") code = await self.write_code(prompt) - logger.info(f'CloneFunction code is \n {highlight(code)}') + logger.info(f"CloneFunction code is \n {highlight(code)}") return code @@ -61,5 +61,5 @@ def run_function_script(code_script_path: str, func_name: str, *args, **kwargs): """Run function code from script.""" if isinstance(code_script_path, str): code_path = Path(code_script_path) - code = code_path.read_text(encoding='utf-8') + code = code_path.read_text(encoding="utf-8") return run_function_code(code, func_name, *args, **kwargs) diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index d69a22dba..304b1bc3e 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -7,8 +7,8 @@ """ import re -from metagpt.logs import logger from metagpt.actions.action import Action +from metagpt.logs import logger from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -24,6 +24,8 @@ The message is as follows: Now you should start rewriting the code: ## file name of the code to rewrite: Write code with triple quoto. Do your best to implement THIS IN ONLY ONE FILE. """ + + class DebugError(Action): def __init__(self, name="DebugError", context=None, llm=None): super().__init__(name, context, llm) @@ -33,17 +35,17 @@ class DebugError(Action): # f"\n\n{error}\n\nPlease try to fix the error in this code." # fixed_code = await self._aask(prompt) # return fixed_code - + async def run(self, context): if "PASS" in context: return "", "the original code works fine, no need to debug" - + file_name = re.search("## File To Rewrite:\s*(.+\\.py)", context).group(1) logger.info(f"Debug and rewrite {file_name}") prompt = PROMPT_TEMPLATE.format(context=context) - + rsp = await self._aask(prompt) code = CodeParser.parse_code(block="", text=rsp) diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index 9bb822a62..7f25bb9a3 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -13,10 +13,11 @@ class DesignReview(Action): super().__init__(name, context, llm) async def run(self, prd, api_design): - prompt = f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed " \ - f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements" \ - f" of the PRD, and whether it complies with good design practices." + prompt = ( + f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed " + f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements" + f" of the PRD, and whether it complies with good design practices." + ) api_review = await self._aask(prompt) return api_review - \ No newline at end of file diff --git a/metagpt/actions/design_filenames.py b/metagpt/actions/design_filenames.py index 29400e950..ffa171d7b 100644 --- a/metagpt/actions/design_filenames.py +++ b/metagpt/actions/design_filenames.py @@ -17,8 +17,10 @@ Do not add any other explanations, just return a Python string list.""" class DesignFilenames(Action): def __init__(self, name, context=None, llm=None): super().__init__(name, context, llm) - self.desc = "Based on the PRD, consider system design, and carry out the basic design of the corresponding " \ - "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." + self.desc = ( + "Based on the PRD, consider system design, and carry out the basic design of the corresponding " + "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." + ) async def run(self, prd): prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}" @@ -26,4 +28,3 @@ class DesignFilenames(Action): logger.debug(prompt) logger.debug(design_filenames) return design_filenames - \ No newline at end of file diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/detail_mining.py index e29d6911b..5afcf52c6 100644 --- a/metagpt/actions/detail_mining.py +++ b/metagpt/actions/detail_mining.py @@ -6,7 +6,6 @@ @File : detail_mining.py """ from metagpt.actions import Action, ActionOutput -from metagpt.logs import logger PROMPT_TEMPLATE = """ ##TOPIC @@ -41,8 +40,8 @@ OUTPUT_MAPPING = { class DetailMining(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion. - """ + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" + def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index b37aa6885..dcf537a58 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -10,8 +10,8 @@ import os import zipfile -from pathlib import Path from datetime import datetime +from pathlib import Path import pandas as pd from paddleocr import PaddleOCR @@ -19,7 +19,10 @@ from paddleocr import PaddleOCR from metagpt.actions import Action from metagpt.const import INVOICE_OCR_TABLE_PATH from metagpt.logs import logger -from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT +from metagpt.prompts.invoice_ocr import ( + EXTRACT_OCR_MAIN_INFO_PROMPT, + REPLY_OCR_QUESTION_PROMPT, +) from metagpt.utils.common import OutputParser from metagpt.utils.file import File @@ -183,4 +186,3 @@ class ReplyQuestion(Action): prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language) resp = await self._aask(prompt=prompt) return resp - diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index 5db3a9f37..b2704616e 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -38,4 +38,3 @@ class PrepareInterview(Action): prompt = PROMPT_TEMPLATE.format(context=context) question_list = await self._aask_v1(prompt) return question_list - diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 49a981e86..d7a2a7e38 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import json from typing import Callable from pydantic import parse_obj_as @@ -49,7 +48,7 @@ based on the link credibility. If two results have equal credibility, prioritize ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. """ -WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements +WEB_BROWSE_AND_SUMMARIZE_PROMPT = """### Requirements 1. Utilize the text in the "Reference Information" section to respond to the question "{query}". 2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \ a comprehensive summary of the text. @@ -58,10 +57,10 @@ a comprehensive summary of the text. ### Reference Information {content} -''' +""" -CONDUCT_RESEARCH_PROMPT = '''### Reference Information +CONDUCT_RESEARCH_PROMPT = """### Reference Information {content} ### Requirements @@ -73,11 +72,12 @@ above. The report must meet the following requirements: - Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable. - The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines. - Include all source URLs in APA format at the end of the report. -''' +""" class CollectLinks(Action): """Action class to collect links from a search engine.""" + def __init__( self, name: str = "", @@ -114,19 +114,24 @@ class CollectLinks(Action): keywords = OutputParser.extract_struct(keywords, list) keywords = parse_obj_as(list[str], keywords) except Exception as e: - logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}") + logger.exception(f'fail to get keywords related to the research topic "{topic}" for {e}') keywords = [topic] results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords)) def gen_msg(): while True: - search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results)) - prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results) + search_results = "\n".join( + f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results) + ) + prompt = SUMMARIZE_SEARCH_PROMPT.format( + decomposition_nums=decomposition_nums, search_results=search_results + ) yield prompt remove = max(results, key=len) remove.pop() if len(remove) == 0: break + prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) @@ -172,6 +177,7 @@ class CollectLinks(Action): class WebBrowseAndSummarize(Action): """Action class to explore the web and provide summaries of articles and webpages.""" + def __init__( self, *args, @@ -214,7 +220,9 @@ class WebBrowseAndSummarize(Action): for u, content in zip([url, *urls], contents): content = content.inner_text chunk_summaries = [] - for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp): + for prompt in generate_prompt_chunk( + content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp + ): logger.debug(prompt) summary = await self._aask(prompt, [system_text]) if summary == "Not relevant.": @@ -238,6 +246,7 @@ class WebBrowseAndSummarize(Action): class ConductResearch(Action): """Action class to conduct research and generate a research report.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if CONFIG.model_for_researcher_report: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 069f2a977..5e4cdaea0 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -140,4 +140,3 @@ class SearchAndSummarize(Action): logger.debug(prompt) logger.debug(result) return result - \ No newline at end of file diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index c000805c5..a922d3694 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -5,13 +5,14 @@ @Author : alexanderwu @File : write_code.py """ +from tenacity import retry, stop_after_attempt, wait_fixed + from metagpt.actions import WriteDesign from metagpt.actions.action import Action from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -74,9 +75,8 @@ class WriteCode(Action): async def run(self, context, filename): prompt = PROMPT_TEMPLATE.format(context=context, filename=filename) - logger.info(f'Writing {filename}..') + logger.info(f"Writing {filename}..") code = await self.write_code(prompt) # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4ff4d6cf6..76adca255 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -6,11 +6,12 @@ @File : write_code_review.py """ +from tenacity import retry, stop_after_attempt, wait_fixed + from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -74,9 +75,8 @@ class WriteCodeReview(Action): async def run(self, context, code, filename): format_example = FORMAT_EXAMPLE.format(filename=filename) prompt = PROMPT_TEMPLATE.format(context=context, code=code, filename=filename, format_example=format_example) - logger.info(f'Code review {filename}..') + logger.info(f"Code review {filename}..") code = await self.write_code(prompt) # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 5c7815793..dd3312bd5 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -28,7 +28,7 @@ from metagpt.actions.action import Action from metagpt.utils.common import OutputParser from metagpt.utils.pycst import merge_docstring -PYTHON_DOCSTRING_SYSTEM = '''### Requirements +PYTHON_DOCSTRING_SYSTEM = """### Requirements 1. Add docstrings to the given code following the {style} style. 2. Replace the function body with an Ellipsis object(...) to reduce output. 3. If the types are already annotated, there is no need to include them in the docstring. @@ -48,7 +48,7 @@ class ExampleError(Exception): ```python {example} ``` -''' +""" # https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html @@ -162,7 +162,8 @@ class WriteDocstring(Action): self.desc = "Write docstring for code." async def run( - self, code: str, + self, + code: str, system_text: str = PYTHON_DOCSTRING_SYSTEM, style: Literal["google", "numpy", "sphinx"] = "google", ) -> str: diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 5c922d3bc..5ff9624c5 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -25,4 +25,3 @@ class WritePRDReview(Action): prompt = self.prd_review_prompt_template.format(prd=self.prd) review = await self._aask(prompt) return review - \ No newline at end of file diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index 23e3560e8..d41915de3 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -10,7 +10,7 @@ from typing import Dict from metagpt.actions import Action -from metagpt.prompts.tutorial_assistant import DIRECTORY_PROMPT, CONTENT_PROMPT +from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT from metagpt.utils.common import OutputParser @@ -65,4 +65,3 @@ class WriteContent(Action): """ prompt = CONTENT_PROMPT.format(topic=topic, language=self.language, directory=self.directory) return await self._aask(prompt=prompt) - diff --git a/metagpt/config.py b/metagpt/config.py index 27455d38d..d93640c1b 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -46,7 +46,7 @@ class Config(metaclass=Singleton): self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and ( - not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key + not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key ): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first") self.openai_api_base = self._get("OPENAI_API_BASE") diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py index 5d7015e8b..7d102e00b 100644 --- a/metagpt/document_store/base_store.py +++ b/metagpt/document_store/base_store.py @@ -41,7 +41,7 @@ class LocalStore(BaseStore, ABC): self.store = self.write() def _get_index_and_store_fname(self): - fname = self.raw_data.name.split('.')[0] + fname = self.raw_data.name.split(".")[0] index_file = self.cache_dir / f"{fname}.index" store_file = self.cache_dir / f"{fname}.pkl" return index_file, store_file @@ -53,4 +53,3 @@ class LocalStore(BaseStore, ABC): @abstractmethod def _write(self, docs, metadatas): raise NotImplementedError - \ No newline at end of file diff --git a/metagpt/document_store/chromadb_store.py b/metagpt/document_store/chromadb_store.py index d2ecc05f6..d7344d41b 100644 --- a/metagpt/document_store/chromadb_store.py +++ b/metagpt/document_store/chromadb_store.py @@ -10,6 +10,7 @@ import chromadb class ChromaStore: """If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange.""" + def __init__(self, name): client = chromadb.Client() collection = client.create_collection(name) @@ -22,7 +23,7 @@ class ChromaStore: query_texts=[query], n_results=n_results, where=metadata_filter, # optional filter - where_document=document_filter # optional filter + where_document=document_filter, # optional filter ) return results diff --git a/metagpt/document_store/document.py b/metagpt/document_store/document.py index e4b9473c7..c59056312 100644 --- a/metagpt/document_store/document.py +++ b/metagpt/document_store/document.py @@ -24,20 +24,20 @@ def validate_cols(content_col: str, df: pd.DataFrame): def read_data(data_path: Path): suffix = data_path.suffix - if '.xlsx' == suffix: + if ".xlsx" == suffix: data = pd.read_excel(data_path) - elif '.csv' == suffix: + elif ".csv" == suffix: data = pd.read_csv(data_path) - elif '.json' == suffix: + elif ".json" == suffix: data = pd.read_json(data_path) - elif suffix in ('.docx', '.doc'): - data = UnstructuredWordDocumentLoader(str(data_path), mode='elements').load() - elif '.txt' == suffix: + elif suffix in (".docx", ".doc"): + data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load() + elif ".txt" == suffix: data = TextLoader(str(data_path)).load() - text_splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0) + text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0) texts = text_splitter.split_documents(data) data = texts - elif '.pdf' == suffix: + elif ".pdf" == suffix: data = UnstructuredPDFLoader(str(data_path), mode="elements").load() else: raise NotImplementedError @@ -45,8 +45,7 @@ def read_data(data_path: Path): class Document: - - def __init__(self, data_path, content_col='content', meta_col='metadata'): + def __init__(self, data_path, content_col="content", meta_col="metadata"): self.data = read_data(data_path) if isinstance(self.data, pd.DataFrame): validate_cols(content_col, self.data) @@ -79,4 +78,3 @@ class Document: return self._get_docs_and_metadatas_by_langchain() else: raise NotImplementedError - \ No newline at end of file diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index dd450010d..8ff904cdd 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -20,7 +20,7 @@ from metagpt.logs import logger class FaissStore(LocalStore): - def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'): + def __init__(self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output"): self.meta_col = meta_col self.content_col = content_col super().__init__(raw_data, cache_dir) @@ -50,7 +50,7 @@ class FaissStore(LocalStore): pickle.dump(store, f) store.index = index - def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs): + def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs): rsp = self.store.similarity_search(query, k=k, **kwargs) logger.debug(rsp) if expand_cols: @@ -78,8 +78,8 @@ class FaissStore(LocalStore): raise NotImplementedError -if __name__ == '__main__': - faiss_store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json') - logger.info(faiss_store.search('Oily Skin Facial Cleanser')) - faiss_store.add([f'Oily Skin Facial Cleanser-{i}' for i in range(3)]) - logger.info(faiss_store.search('Oily Skin Facial Cleanser')) +if __name__ == "__main__": + faiss_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json") + logger.info(faiss_store.search("Oily Skin Facial Cleanser")) + faiss_store.add([f"Oily Skin Facial Cleanser-{i}" for i in range(3)]) + logger.info(faiss_store.search("Oily Skin Facial Cleanser")) diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py index 77a8ec141..fcfc59d79 100644 --- a/metagpt/document_store/milvus_store.py +++ b/metagpt/document_store/milvus_store.py @@ -12,12 +12,7 @@ from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connec from metagpt.document_store.base_store import BaseStore -type_mapping = { - int: DataType.INT64, - str: DataType.VARCHAR, - float: DataType.DOUBLE, - np.ndarray: DataType.FLOAT_VECTOR -} +type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR} def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""): @@ -52,17 +47,11 @@ class MilvusStore(BaseStore): self.collection = None def _create_collection(self, name, schema): - collection = Collection( - name=name, - schema=schema, - using='default', - shards_num=2, - consistency_level="Strong" - ) + collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong") return collection def create_collection(self, name, columns): - schema = columns_to_milvus_schema(columns, 'idx') + schema = columns_to_milvus_schema(columns, "idx") self.collection = self._create_collection(name, schema) return self.collection @@ -72,7 +61,7 @@ class MilvusStore(BaseStore): def load_collection(self): self.collection.load() - def build_index(self, field='emb'): + def build_index(self, field="emb"): self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}}) def search(self, query: list[list[float]], *args, **kwargs): @@ -85,11 +74,11 @@ class MilvusStore(BaseStore): search_params = {"metric_type": "L2", "params": {"nprobe": 10}} results = self.collection.search( data=query, - anns_field=kwargs.get('field', 'emb'), + anns_field=kwargs.get("field", "emb"), param=search_params, limit=10, expr=None, - consistency_level="Strong" + consistency_level="Strong", ) # FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface return results diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py index 98b82cf87..4e9637aa7 100644 --- a/metagpt/document_store/qdrant_store.py +++ b/metagpt/document_store/qdrant_store.py @@ -10,13 +10,14 @@ from metagpt.document_store.base_store import BaseStore @dataclass class QdrantConnection: """ - Args: - url: qdrant url - host: qdrant host - port: qdrant port - memory: qdrant service use memory mode - api_key: qdrant cloud api_key - """ + Args: + url: qdrant url + host: qdrant host + port: qdrant port + memory: qdrant service use memory mode + api_key: qdrant cloud api_key + """ + url: str = None host: str = None port: int = None @@ -31,9 +32,7 @@ class QdrantStore(BaseStore): elif connect.url: self.client = QdrantClient(url=connect.url, api_key=connect.api_key) elif connect.host and connect.port: - self.client = QdrantClient( - host=connect.host, port=connect.port, api_key=connect.api_key - ) + self.client = QdrantClient(host=connect.host, port=connect.port, api_key=connect.api_key) else: raise Exception("please check QdrantConnection.") @@ -58,15 +57,11 @@ class QdrantStore(BaseStore): try: self.client.get_collection(collection_name) if force_recreate: - res = self.client.recreate_collection( - collection_name, vectors_config=vectors_config, **kwargs - ) + res = self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs) return res return True except: # noqa: E722 - return self.client.recreate_collection( - collection_name, vectors_config=vectors_config, **kwargs - ) + return self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs) def has_collection(self, collection_name: str): try: diff --git a/metagpt/environment.py b/metagpt/environment.py index 24e6ada2f..2e2aa152a 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -17,34 +17,34 @@ from metagpt.schema import Message class Environment(BaseModel): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 - Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles - + Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles + """ roles: dict[str, Role] = Field(default_factory=dict) memory: Memory = Field(default_factory=Memory) - history: str = Field(default='') + history: str = Field(default="") class Config: arbitrary_types_allowed = True def add_role(self, role: Role): """增加一个在当前环境的角色 - Add a role in the current environment + Add a role in the current environment """ role.set_env(self) self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 - Add a batch of characters in the current environment + Add a batch of characters in the current environment """ for role in roles: self.add_role(role) def publish_message(self, message: Message): """向当前环境发布信息 - Post information to the current environment + Post information to the current environment """ # self.message_queue.put(message) self.memory.add(message) @@ -68,12 +68,12 @@ class Environment(BaseModel): def get_roles(self) -> dict[str, Role]: """获得环境内的所有角色 - Process all Role runs at once + Process all Role runs at once """ return self.roles def get_role(self, name: str) -> Role: """获得环境内的指定角色 - get all the environment roles + get all the environment roles """ return self.roles.get(name, None) diff --git a/metagpt/inspect_module.py b/metagpt/inspect_module.py index a89ac1c5e..48ceffc57 100644 --- a/metagpt/inspect_module.py +++ b/metagpt/inspect_module.py @@ -12,17 +12,17 @@ import metagpt # replace with your module def print_classes_and_functions(module): - """FIXME: NOT WORK.. """ + """FIXME: NOT WORK..""" for name, obj in inspect.getmembers(module): if inspect.isclass(obj): - print(f'Class: {name}') + print(f"Class: {name}") elif inspect.isfunction(obj): - print(f'Function: {name}') + print(f"Function: {name}") else: print(name) print(dir(module)) -if __name__ == '__main__': - print_classes_and_functions(metagpt) \ No newline at end of file +if __name__ == "__main__": + print_classes_and_functions(metagpt) diff --git a/metagpt/llm.py b/metagpt/llm.py index e6f815950..410f3dcb5 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -12,8 +12,9 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM DEFAULT_LLM = LLM() CLAUDE_LLM = Claude() + async def ai_func(prompt): """使用LLM进行QA - QA with LLMs - """ + QA with LLMs + """ return await DEFAULT_LLM.aask(prompt) diff --git a/metagpt/logs.py b/metagpt/logs.py index b2052e9b8..55d85312f 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -12,13 +12,15 @@ from loguru import logger as _logger from metagpt.const import PROJECT_ROOT + def define_log_level(print_level="INFO", logfile_level="DEBUG"): """调整日志级别到level之上 - Adjust the log level to above level + Adjust the log level to above level """ _logger.remove() _logger.add(sys.stderr, level=print_level) - _logger.add(PROJECT_ROOT / 'logs/log.txt', level=logfile_level) + _logger.add(PROJECT_ROOT / "logs/log.txt", level=logfile_level) return _logger + logger = define_log_level() diff --git a/metagpt/management/skill_manager.py b/metagpt/management/skill_manager.py index f967a0a94..b3181b64e 100644 --- a/metagpt/management/skill_manager.py +++ b/metagpt/management/skill_manager.py @@ -19,8 +19,8 @@ class SkillManager: def __init__(self): self._llm = LLM() - self._store = ChromaStore('skill_manager') - self._skills: dict[str: Skill] = {} + self._store = ChromaStore("skill_manager") + self._skills: dict[str:Skill] = {} def add_skill(self, skill: Skill): """ @@ -54,7 +54,7 @@ class SkillManager: :param desc: Skill description :return: Multiple skills """ - return self._store.search(desc, n_results=n_results)['ids'][0] + return self._store.search(desc, n_results=n_results)["ids"][0] def retrieve_skill_scored(self, desc: str, n_results: int = 2) -> dict: """ @@ -75,6 +75,6 @@ class SkillManager: logger.info(text) -if __name__ == '__main__': +if __name__ == "__main__": manager = SkillManager() manager.generate_skill_desc(Action()) diff --git a/metagpt/manager.py b/metagpt/manager.py index 9d238c621..d0b6b101c 100644 --- a/metagpt/manager.py +++ b/metagpt/manager.py @@ -18,7 +18,7 @@ class Manager: "Product Manager": "Architect", "Architect": "Engineer", "Engineer": "QA Engineer", - "QA Engineer": "Product Manager" + "QA Engineer": "Product Manager", } self.prompt_template = """ Given the following message: @@ -51,7 +51,7 @@ class Manager: # chosen_role_name = self.llm.ask(self.prompt_template.format(context)) # FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程 - #The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards + # The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards next_role_profile = self.role_directions[message.role] # logger.debug(f"{next_role_profile}") for _, role in roles.items(): diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index f8abea5f3..e0b8e68c1 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -68,4 +68,3 @@ class LongTermMemory(Memory): def clear(self): super(LongTermMemory, self).clear() self.memory_storage.clean() - \ No newline at end of file diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index c818fa707..282f5fe33 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -85,4 +85,3 @@ class Memory: continue rsp += self.index[action] return rsp - \ No newline at end of file diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index 302d96aa7..a213f6d7a 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -2,16 +2,16 @@ # -*- coding: utf-8 -*- # @Desc : the implement of memory storage -from typing import List from pathlib import Path +from typing import List from langchain.vectorstores.faiss import FAISS from metagpt.const import DATA_PATH, MEM_TTL +from metagpt.document_store.faiss_store import FaissStore from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.serialize import serialize_message, deserialize_message -from metagpt.document_store.faiss_store import FaissStore +from metagpt.utils.serialize import deserialize_message, serialize_message class MemoryStorage(FaissStore): @@ -34,7 +34,7 @@ class MemoryStorage(FaissStore): def recover_memory(self, role_id: str) -> List[Message]: self.role_id = role_id - self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/') + self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/") self.role_mem_path.mkdir(parents=True, exist_ok=True) self.store = self._load() @@ -51,18 +51,18 @@ class MemoryStorage(FaissStore): def _get_index_and_store_fname(self): if not self.role_mem_path: - logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory') + logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory") return None, None - index_fpath = Path(self.role_mem_path / f'{self.role_id}.index') - storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl') + index_fpath = Path(self.role_mem_path / f"{self.role_id}.index") + storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl") return index_fpath, storage_fpath def persist(self): super(MemoryStorage, self).persist() - logger.debug(f'Agent {self.role_id} persist memory into local') + logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: - """ add message into memory storage""" + """add message into memory storage""" docs = [message.content] metadatas = [{"message_ser": serialize_message(message)}] if not self.store: @@ -79,10 +79,7 @@ class MemoryStorage(FaissStore): if not self.store: return [] - resp = self.store.similarity_search_with_score( - query=message.content, - k=k - ) + resp = self.store.similarity_search_with_score(query=message.content, k=k) # filter the result which score is smaller than the threshold filtered_resp = [] for item, score in resp: @@ -104,4 +101,3 @@ class MemoryStorage(FaissStore): self.store = None self._initialized = False - \ No newline at end of file diff --git a/metagpt/prompts/invoice_ocr.py b/metagpt/prompts/invoice_ocr.py index 52f628a5b..aa79651be 100644 --- a/metagpt/prompts/invoice_ocr.py +++ b/metagpt/prompts/invoice_ocr.py @@ -10,7 +10,9 @@ COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice." -EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """ +EXTRACT_OCR_MAIN_INFO_PROMPT = ( + COMMON_PROMPT + + """ Please extract the payee, city, total cost, and invoicing date of the invoice. The OCR data of the invoice are as follows: @@ -22,8 +24,11 @@ Mandatory restrictions are returned according to the following requirements: 2. The returned JSON dictionary must be returned in {language} 3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}. """ +) -REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """ +REPLY_OCR_QUESTION_PROMPT = ( + COMMON_PROMPT + + """ Please answer the question: {query} The OCR data of the invoice are as follows: @@ -34,6 +39,6 @@ Mandatory restrictions are returned according to the following requirements: 2. Enforce restrictions on not returning OCR data sent to you. 3. Return with markdown syntax layout. """ +) INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice." - diff --git a/metagpt/prompts/sales.py b/metagpt/prompts/sales.py index a44aacafe..30ef1ae02 100644 --- a/metagpt/prompts/sales.py +++ b/metagpt/prompts/sales.py @@ -54,10 +54,12 @@ Conversation history: {salesperson_name}: """ -conversation_stages = {'1' : "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.", -'2': "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.", -'3': "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.", -'4': "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.", -'5': "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.", -'6': "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.", -'7': "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits."} +conversation_stages = { + "1": "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.", + "2": "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.", + "3": "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.", + "4": "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.", + "5": "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.", + "6": "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.", + "7": "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.", +} diff --git a/metagpt/prompts/tutorial_assistant.py b/metagpt/prompts/tutorial_assistant.py index d690aad83..3d4b6fa24 100644 --- a/metagpt/prompts/tutorial_assistant.py +++ b/metagpt/prompts/tutorial_assistant.py @@ -12,7 +12,9 @@ You are now a seasoned technical professional in the field of the internet. We need you to write a technical tutorial with the topic "{topic}". """ -DIRECTORY_PROMPT = COMMON_PROMPT + """ +DIRECTORY_PROMPT = ( + COMMON_PROMPT + + """ Please provide the specific table of contents for this tutorial, strictly following the following requirements: 1. The output must be strictly in the specified language, {language}. 2. Answer strictly in the dictionary format like {{"title": "xxx", "directory": [{{"dir 1": ["sub dir 1", "sub dir 2"]}}, {{"dir 2": ["sub dir 3", "sub dir 4"]}}]}}. @@ -20,8 +22,11 @@ Please provide the specific table of contents for this tutorial, strictly follow 4. Do not have extra spaces or line breaks. 5. Each directory title has practical significance. """ +) -CONTENT_PROMPT = COMMON_PROMPT + """ +CONTENT_PROMPT = ( + COMMON_PROMPT + + """ Now I will give you the module directory titles for the topic. Please output the detailed principle content of this title in detail. If there are code examples, please provide them according to standard code specifications. @@ -36,4 +41,5 @@ Strictly limit output according to the following requirements: 3. The output must be strictly in the specified language, {language}. 4. Do not have redundant output, including concluding remarks. 5. Strict requirement not to output the topic "{topic}". -""" \ No newline at end of file +""" +) diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 7293e2cde..03802a716 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -32,4 +32,3 @@ class Claude2: max_tokens_to_sample=1000, ) return res.completion - \ No newline at end of file diff --git a/metagpt/provider/base_chatbot.py b/metagpt/provider/base_chatbot.py index abdf423f4..2d4cfe2d9 100644 --- a/metagpt/provider/base_chatbot.py +++ b/metagpt/provider/base_chatbot.py @@ -12,6 +12,7 @@ from dataclasses import dataclass @dataclass class BaseChatbot(ABC): """Abstract GPT class""" + mode: str = "API" @abstractmethod @@ -25,4 +26,3 @@ class BaseChatbot(ABC): @abstractmethod def ask_code(self, msgs: list) -> str: """Ask GPT multiple questions and get a piece of code""" - \ No newline at end of file diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index de61167b9..adc57c66b 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -14,7 +14,8 @@ from metagpt.provider.base_chatbot import BaseChatbot class BaseGPTAPI(BaseChatbot): """GPT API abstract class, requiring all inheritors to provide a series of standard capabilities""" - system_prompt = 'You are a helpful assistant.' + + system_prompt = "You are a helpful assistant." def _user_msg(self, msg: str) -> dict[str, str]: return {"role": "user", "content": msg} @@ -110,9 +111,8 @@ class BaseGPTAPI(BaseChatbot): def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" - return '\n'.join([f"{i['role']}: {i['content']}" for i in messages]) + return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) def messages_to_dict(self, messages): """objects to [{"role": "user", "content": msg}] etc.""" return [i.to_dict() for i in messages] - \ No newline at end of file diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 6ebed2c16..ac0edd44f 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -110,7 +110,6 @@ class CostManager(metaclass=Singleton): """ return self.total_completion_tokens - def get_total_cost(self): """ Get the total cost of API calls. @@ -120,7 +119,6 @@ class CostManager(metaclass=Singleton): """ return self.total_cost - def get_costs(self) -> Costs: """Get all costs""" return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 55f7000ec..60c86f4dc 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -14,8 +14,7 @@ import json import ssl from time import mktime from typing import Optional -from urllib.parse import urlencode -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time import websocket # 使用websocket_client @@ -26,9 +25,8 @@ from metagpt.provider.base_gpt_api import BaseGPTAPI class SparkAPI(BaseGPTAPI): - def __init__(self): - logger.warning('当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。') + logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") def ask(self, msg: str) -> str: message = [self._default_system_msg(), self._user_msg(msg)] @@ -49,7 +47,7 @@ class SparkAPI(BaseGPTAPI): async def acompletion_text(self, messages: list[dict], stream=False) -> str: # 不支持 - logger.error('该功能禁用。') + logger.error("该功能禁用。") w = GetMessageFromWeb(messages) return w.run() @@ -93,29 +91,26 @@ class GetMessageFromWeb: signature_origin += "GET " + self.path + " HTTP/1.1" # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": self.host - } + v = {"authorization": authorization, "date": date, "host": self.host} # 拼接鉴权参数,生成url - url = self.spark_url + '?' + urlencode(v) + url = self.spark_url + "?" + urlencode(v) # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 return url def __init__(self, text): self.text = text - self.ret = '' + self.ret = "" self.spark_appid = CONFIG.spark_appid self.spark_api_secret = CONFIG.spark_api_secret self.spark_api_key = CONFIG.spark_api_key @@ -124,15 +119,15 @@ class GetMessageFromWeb: def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: ws.close() # 请求错误,则关闭socket - logger.critical(f'回答获取失败,响应信息反序列化之后为: {data}') + logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}") return else: choices = data["payload"]["choices"] - seq = choices["seq"] # 服务端是流式返回,seq为返回的数据序号 + # seq = choices["seq"] # 服务端是流式返回,seq为返回的数据序号 status = choices["status"] # 服务端是流式返回,status用于判断信息是否传送完毕 content = choices["text"][0]["content"] # 本次接收到的回答文本 self.ret += content @@ -142,7 +137,7 @@ class GetMessageFromWeb: # 收到websocket错误的处理 def on_error(self, ws, error): # on_message方法处理接收到的信息,出现任何错误,都会调用这个方法 - logger.critical(f'通讯连接出错,【错误提示: {error}】') + logger.critical(f"通讯连接出错,【错误提示: {error}】") # 收到websocket关闭的处理 def on_close(self, ws, one, two): @@ -150,17 +145,12 @@ class GetMessageFromWeb: # 处理请求数据 def gen_params(self): - data = { - "header": { - "app_id": self.spark_appid, - "uid": "1234" - }, + "header": {"app_id": self.spark_appid, "uid": "1234"}, "parameter": { "chat": { # domain为必传参数 "domain": self.domain, - # 以下为可微调,非必传参数 # 注意:官方建议,temperature和top_k修改一个即可 "max_tokens": 2048, # 默认2048,模型回答的tokens的最大长度,即允许它输出文本的最长字数 @@ -168,11 +158,7 @@ class GetMessageFromWeb: "top_k": 4, # 取值为[1,6],默认为4。从k个候选中随机选择一个(非等概率) } }, - "payload": { - "message": { - "text": self.text - } - } + "payload": {"message": {"text": self.text}}, } return data @@ -189,17 +175,12 @@ class GetMessageFromWeb: return self._run(self.text) def _run(self, text_list): - - ws_param = self.WsParam( - self.spark_appid, - self.spark_api_key, - self.spark_api_secret, - self.spark_url, - text_list) + ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list) ws_url = ws_param.create_url() websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能 - ws = websocket.WebSocketApp(ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open) + ws = websocket.WebSocketApp( + ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open + ) ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) return self.ret diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 4547f8190..188182d47 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -24,12 +24,5 @@ DESC = """ class CustomerService(Sales): - def __init__( - self, - name="Xiaomei", - profile="Human customer service", - desc=DESC, - store=None - ): + def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None): super().__init__(name, profile, desc=desc, store=store) - \ No newline at end of file diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py index c307b20c0..3087a4da7 100644 --- a/metagpt/roles/invoice_ocr_assistant.py +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -9,7 +9,7 @@ import pandas as pd -from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion +from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS from metagpt.roles import Role from metagpt.schema import Message @@ -107,4 +107,3 @@ class InvoiceOCRAssistant(Role): break msg = await self._act() return msg - diff --git a/metagpt/roles/prompt.py b/metagpt/roles/prompt.py index c22e0226b..457ccb6c6 100644 --- a/metagpt/roles/prompt.py +++ b/metagpt/roles/prompt.py @@ -23,6 +23,7 @@ SUFFIX = """Let's begin! Question: {input} Thoughts: {agent_scratchpad}""" + class PromptString(Enum): REFLECTION_QUESTIONS = "Here are some statements:\n{memory_descriptions}\n\nBased solely on the information above, what are the 3 most prominent high-level questions we can answer about the topic in the statements?\n\n{format_instructions}" @@ -32,7 +33,7 @@ class PromptString(Enum): RECENT_ACTIVITY = "Based on the following memory, produce a brief summary of what {full_name} has been up to recently. Do not invent details not explicitly stated in the memory. For any conversation, be sure to mention whether the conversation has concluded or is still ongoing.\n\nMemory: {memory_descriptions}" - MAKE_PLANS = "You are a plan-generating AI. Your job is to assist the character in formulating new plans based on new information. Given the character's information (profile, objectives, recent activities, current plans, and location context) and their current thought process, produce a new set of plans for them. The final plan should comprise at least {time_window} of activities and no more than 5 individual plans. List the plans in the order they should be executed, with each plan detailing its description, location, start time, stop criteria, and maximum duration.\n\nSample plan: {{\"index\": 1, \"description\": \"Cook dinner\", \"location_id\": \"0a3bc22b-36aa-48ab-adb0-18616004caed\",\"start_time\": \"2022-12-12T20:00:00+00:00\",\"max_duration_hrs\": 1.5, \"stop_condition\": \"Dinner is fully prepared\"}}\'\n\nFor each plan, choose the most appropriate location name from this list: {allowed_location_descriptions}\n\n{format_instructions}\n\nAlways prioritize completing any unfinished conversations.\n\nLet's begin!\n\nName: {full_name}\nProfile: {private_bio}\nObjectives: {directives}\nLocation Context: {location_context}\nCurrent Plans: {current_plans}\nRecent Activities: {recent_activity}\nThought Process: {thought_process}\nIt's essential to encourage the character to collaborate with other characters in their plans.\n\n" + MAKE_PLANS = 'You are a plan-generating AI. Your job is to assist the character in formulating new plans based on new information. Given the character\'s information (profile, objectives, recent activities, current plans, and location context) and their current thought process, produce a new set of plans for them. The final plan should comprise at least {time_window} of activities and no more than 5 individual plans. List the plans in the order they should be executed, with each plan detailing its description, location, start time, stop criteria, and maximum duration.\n\nSample plan: {{"index": 1, "description": "Cook dinner", "location_id": "0a3bc22b-36aa-48ab-adb0-18616004caed","start_time": "2022-12-12T20:00:00+00:00","max_duration_hrs": 1.5, "stop_condition": "Dinner is fully prepared"}}\'\n\nFor each plan, choose the most appropriate location name from this list: {allowed_location_descriptions}\n\n{format_instructions}\n\nAlways prioritize completing any unfinished conversations.\n\nLet\'s begin!\n\nName: {full_name}\nProfile: {private_bio}\nObjectives: {directives}\nLocation Context: {location_context}\nCurrent Plans: {current_plans}\nRecent Activities: {recent_activity}\nThought Process: {thought_process}\nIt\'s essential to encourage the character to collaborate with other characters in their plans.\n\n' EXECUTE_PLAN = "You are a role-playing AI, playing the role of {your_name}, in front of a live audience. Every word you say can be observed by the audience, so make sure you speak often and make it interesting. You cannot interact directly with the audience.\n\nBased on the following context and tools, proceed as if you are {your_name}. Your primary task is to complete the task given below, however, if you are currently in a conversation with another character, you should always complete the conversation first and then proceed with the task. Do not start working while you are engaged in an unfinished conversation. Use your best judgment to determine whether a conversation involves you and whether it is unfinished. You don't need to respond to every message you receive.\n\nThis is a personal profile of your character:\n{your_private_bio}\n\nThis is a description of your location and other nearby characters you can converse with:\n\n{location_context}\n\nThese are some relevant memories:\n```\n{relevant_memories}\n```\n\nThis is some relevant conversation history:\n```\n{conversation_history}\n```\n\nThese are the tools you can use:\n{tools}\n\nYour response should follow the format below:\n\nTask: The task you must complete\nThoughts: What you should think about doing\nAction: The action to take, must be one of these words: [{tool_names}]\nAction Input: Input for the action\nObservation: The result of the action\n... (This Thoughts/Action/Action Input/Observation can be repeated N times)\nThoughts: 'I have completed the task'\nFinal Response: The final response to the task\n\nIf you are not ready with a final response, then you must take an action.\n\nIf you are certain that you cannot complete the task with the tools provided, return 'Final Response: Need Help', however, if you are in conversation with another character, responses like 'I don't know' are valid. During conversations, you should never break character or admit you are an AI.\nIf the task is completed and no specific response is required, return 'Final Response: Done'\nLet's begin!\n\nTask: {input}\n\n{agent_scratchpad}" @@ -40,6 +41,6 @@ class PromptString(Enum): GOSSIP = "You are {full_name}. \n{memory_descriptions}\n\nBased on the statements above, say a thing or two of interest to others at your location: {other_agent_names}.\nAlways specify their names when referring to others." - HAS_HAPPENED = "Given the descriptions of the observations of the following characters and the events they are awaiting, indicate whether the character has witnessed the event.\n{format_instructions}\n\nExample:\n\nObservations:\nJoe entered the office at 2023-05-04 08:00:00+00:00\nJoe said hi to Sally at 2023-05-04 08:05:00+00:00\nSally said hello to Joe at 2023-05-04 08:05:30+00:00\nRebecca started working at 2023-05-04 08:10:00+00:00\nJoe made some breakfast at 2023-05-04 08:15:00+00:00\n\nAwaiting: Sally responded to Joe\n\nYour response: '{{\"has_happened\": true, \"date_occured\": 2023-05-04 08:05:30+00:00}}'\n\nLet's begin!\n\nObservations:\n{memory_descriptions}\n\nAwaiting: {event_description}\n" + HAS_HAPPENED = 'Given the descriptions of the observations of the following characters and the events they are awaiting, indicate whether the character has witnessed the event.\n{format_instructions}\n\nExample:\n\nObservations:\nJoe entered the office at 2023-05-04 08:00:00+00:00\nJoe said hi to Sally at 2023-05-04 08:05:00+00:00\nSally said hello to Joe at 2023-05-04 08:05:30+00:00\nRebecca started working at 2023-05-04 08:10:00+00:00\nJoe made some breakfast at 2023-05-04 08:15:00+00:00\n\nAwaiting: Sally responded to Joe\n\nYour response: \'{{"has_happened": true, "date_occured": 2023-05-04 08:05:30+00:00}}\'\n\nLet\'s begin!\n\nObservations:\n{memory_descriptions}\n\nAwaiting: {event_description}\n' OUTPUT_FORMAT = "\n\n(Remember! Make sure your output always adheres to one of the following two formats:\n\nA. If you have completed the task:\nThoughts: 'I have completed the task'\nFinal Response: \n\nB. If you haven't completed the task:\nThoughts: \nAction: \nAction Input: \nObservation: )\n" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 44bb3e976..282431bf7 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -11,12 +11,13 @@ from typing import Iterable, Type from pydantic import BaseModel, Field +from metagpt.actions import Action, ActionOutput + # from metagpt.environment import Environment from metagpt.config import CONFIG -from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.memory import Memory, LongTermMemory +from metagpt.memory import LongTermMemory, Memory from metagpt.schema import Message PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -49,6 +50,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi class RoleSetting(BaseModel): """Role Settings""" + name: str profile: str goal: str @@ -64,7 +66,8 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: 'Environment' = Field(default=None) + + env: "Environment" = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=0) @@ -128,7 +131,7 @@ class Role: logger.debug(self._actions) self._rc.todo = self._actions[self._rc.state] - def set_env(self, env: 'Environment'): + def set_env(self, env: "Environment"): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env @@ -150,12 +153,13 @@ class Role: self._set_state(0) return prompt = self._get_prefix() - prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states), - n_states=len(self._states) - 1) + prompt += STATE_TEMPLATE.format( + history=self._rc.history, states="\n".join(self._states), n_states=len(self._states) - 1 + ) next_state = await self._llm.aask(prompt) logger.debug(f"{prompt=}") if not next_state.isdigit() or int(next_state) not in range(len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}') + logger.warning(f"Invalid answer of state, {next_state=}") next_state = "0" self._set_state(int(next_state)) @@ -168,8 +172,12 @@ class Role: response = await self._rc.todo.run(self._rc.important_memory) # logger.info(response) if isinstance(response, ActionOutput): - msg = Message(content=response.content, instruct_content=response.instruct_content, - role=self.profile, cause_by=type(self._rc.todo)) + msg = Message( + content=response.content, + instruct_content=response.instruct_content, + role=self.profile, + cause_by=type(self._rc.todo), + ) else: msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) @@ -184,15 +192,17 @@ class Role: env_msgs = self._rc.env.memory.get() observed = self._rc.env.memory.get_by_actions(self._rc.watch) - - self._rc.news = self._rc.memory.find_news(observed) # find news (previously unseen messages) from observed messages + + self._rc.news = self._rc.memory.find_news( + observed + ) # find news (previously unseen messages) from observed messages for i in env_msgs: self.recv(i) news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: - logger.debug(f'{self._setting} observed: {news_text}') + logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) def _publish_message(self, msg): diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index a45ad6f1b..18282a494 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -12,16 +12,16 @@ from metagpt.tools import SearchEngineType class Sales(Role): def __init__( - self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None + self, + name="Xiaomei", + profile="Retail sales guide", + desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + store=None, ): super().__init__(name, profile, desc=desc) self._set_store(store) @@ -32,4 +32,3 @@ class Sales(Role): else: action = SearchAndSummarize() self._init_actions([action]) - \ No newline at end of file diff --git a/metagpt/roles/seacher.py b/metagpt/roles/seacher.py index 0b6e089da..a2c4896e2 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/seacher.py @@ -15,7 +15,7 @@ from metagpt.tools import SearchEngineType class Searcher(Role): """ Represents a Searcher role responsible for providing search services to users. - + Attributes: name (str): Name of the searcher. profile (str): Role profile. @@ -23,17 +23,19 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - - def __init__(self, - name: str = 'Alice', - profile: str = 'Smart Assistant', - goal: str = 'Provide search services for users', - constraints: str = 'Answer is rich and complete', - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs) -> None: + + def __init__( + self, + name: str = "Alice", + profile: str = "Smart Assistant", + goal: str = "Provide search services for users", + constraints: str = "Answer is rich and complete", + engine=SearchEngineType.SERPAPI_GOOGLE, + **kwargs, + ) -> None: """ Initializes the Searcher role with given attributes. - + Args: name (str): Name of the searcher. profile (str): Role profile. @@ -53,10 +55,14 @@ class Searcher(Role): """Performs the search action in a single process.""" logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.memory.get(k=0)) - + if isinstance(response, ActionOutput): - msg = Message(content=response.content, instruct_content=response.instruct_content, - role=self.profile, cause_by=type(self._rc.todo)) + msg = Message( + content=response.content, + instruct_content=response.instruct_content, + role=self.profile, + cause_by=type(self._rc.todo), + ) else: msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 9a7df4f4d..2a514f433 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -9,7 +9,7 @@ from datetime import datetime from typing import Dict -from metagpt.actions.write_tutorial import WriteDirectory, WriteContent +from metagpt.actions.write_tutorial import WriteContent, WriteDirectory from metagpt.const import TUTORIAL_PATH from metagpt.logs import logger from metagpt.roles import Role @@ -110,5 +110,5 @@ class TutorialAssistant(Role): break msg = await self._act() root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - await File.write(root_path, f"{self.main_title}.md", self.total_content.encode('utf-8')) + await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) return msg diff --git a/metagpt/schema.py b/metagpt/schema.py index bdca093c2..19c7a6654 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -23,9 +23,10 @@ class RawMessage(TypedDict): @dataclass class Message: """list[: ]""" + content: str instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant + role: str = field(default="user") # system / user / assistant cause_by: Type["Action"] = field(default="") sent_from: str = field(default="") send_to: str = field(default="") @@ -39,45 +40,45 @@ class Message: return self.__str__() def to_dict(self) -> dict: - return { - "role": self.role, - "content": self.content - } + return {"role": self.role, "content": self.content} @dataclass class UserMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'user') + super().__init__(content, "user") @dataclass class SystemMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'system') + super().__init__(content, "system") @dataclass class AIMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'assistant') + super().__init__(content, "assistant") -if __name__ == '__main__': - test_content = 'test_message' +if __name__ == "__main__": + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] logger.info(msgs) diff --git a/metagpt/software_company.py b/metagpt/software_company.py index b2bd18c58..d3c2c463b 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -21,6 +21,7 @@ class SoftwareCompany(BaseModel): Software Company: Possesses a team, SOP (Standard Operating Procedures), and a platform for instant messaging, dedicated to writing executable code. """ + environment: Environment = Field(default_factory=Environment) investment: float = Field(default=10.0) idea: str = Field(default="") @@ -36,11 +37,11 @@ class SoftwareCompany(BaseModel): """Invest company. raise NoMoneyException when exceed max_budget.""" self.investment = investment CONFIG.max_budget = investment - logger.info(f'Investment: ${investment}.') + logger.info(f"Investment: ${investment}.") def _check_balance(self): if CONFIG.total_cost > CONFIG.max_budget: - raise NoMoneyException(CONFIG.total_cost, f'Insufficient funds: {CONFIG.max_budget}') + raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}") def start_project(self, idea): """Start a project from publishing boss requirement.""" @@ -59,4 +60,3 @@ class SoftwareCompany(BaseModel): self._check_balance() await self.environment.run() return self.environment.history - \ No newline at end of file diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py index e41eaab72..1cba005fa 100644 --- a/metagpt/tools/code_interpreter.py +++ b/metagpt/tools/code_interpreter.py @@ -1,22 +1,26 @@ +import inspect import re -from typing import List, Callable, Dict +import textwrap from pathlib import Path +from typing import Callable, Dict, List import wrapt -import textwrap -import inspect from interpreter.core.core import Interpreter -from metagpt.logs import logger +from metagpt.actions.clone_function import ( + CloneFunction, + run_function_code, + run_function_script, +) from metagpt.config import CONFIG +from metagpt.logs import logger from metagpt.utils.highlight import highlight -from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script def extract_python_code(code: str): """Extract code blocks: If the code comments are the same, only the last code block is kept.""" # Use regular expressions to match comment blocks and related code. - pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)' + pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)" matches = re.findall(pattern, code, re.DOTALL) # Extract the last code block when encountering the same comment. @@ -25,8 +29,8 @@ def extract_python_code(code: str): unique_comments[comment] = code_block # concatenate into functional form - result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) - header_code = code[:code.find("#")] + result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) + header_code = code[: code.find("#")] code = header_code + result_code logger.info(f"Extract python code: \n {highlight(code)}") @@ -36,6 +40,7 @@ def extract_python_code(code: str): class OpenCodeInterpreter(object): """https://github.com/KillianLucas/open-interpreter""" + def __init__(self, auto_run: bool = True) -> None: interpreter = Interpreter() interpreter.auto_run = auto_run @@ -50,15 +55,16 @@ class OpenCodeInterpreter(object): return self.interpreter.chat(query) @staticmethod - def extract_function(query_respond: List, function_name: str, *, language: str = 'python', - function_format: str = None) -> str: + def extract_function( + query_respond: List, function_name: str, *, language: str = "python", function_format: str = None + ) -> str: """create a function from query_respond.""" - if language not in ('python'): + if language not in ("python"): raise NotImplementedError(f"Not support to parse language {language}!") # set function form if function_format is None: - assert language == 'python', f"Expect python language for default function_format, but got {language}." + assert language == "python", f"Expect python language for default function_format, but got {language}." function_format = """def {function_name}():\n{code}""" # Extract the code module in the open-interpreter respond message. # The query_respond of open-interpreter before v0.1.4 is: @@ -68,25 +74,29 @@ class OpenCodeInterpreter(object): # "parsed_arguments": {"language": "python", "code": code of first plan} # ...] if "function_call" in query_respond[1]: - code = [item['function_call']['parsed_arguments']['code'] for item in query_respond - if "function_call" in item - and "parsed_arguments" in item["function_call"] - and 'language' in item["function_call"]['parsed_arguments'] - and item["function_call"]['parsed_arguments']['language'] == language] + code = [ + item["function_call"]["parsed_arguments"]["code"] + for item in query_respond + if "function_call" in item + and "parsed_arguments" in item["function_call"] + and "language" in item["function_call"]["parsed_arguments"] + and item["function_call"]["parsed_arguments"]["language"] == language + ] # The query_respond of open-interpreter v0.1.7 is: # [{'role': 'user', 'message': your query string}, # {'role': 'assistant', 'message': plan from llm, 'language': 'python', # 'code': code of first plan, 'output': output of first plan code}, # ...] elif "code" in query_respond[1]: - code = [item['code'] for item in query_respond - if "code" in item - and 'language' in item - and item['language'] == language] + code = [ + item["code"] + for item in query_respond + if "code" in item and "language" in item and item["language"] == language + ] else: raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") # add indent. - indented_code_str = textwrap.indent("\n".join(code), ' ' * 4) + indented_code_str = textwrap.indent("\n".join(code), " " * 4) # Return the code after deduplication. if language == "python": return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str)) @@ -115,13 +125,13 @@ class OpenInterpreterDecorator(object): def _have_code(self, rsp: List[Dict]): # Is there any code generated? - return 'code' in rsp[1] and rsp[1]['code'] not in ("", None) + return "code" in rsp[1] and rsp[1]["code"] not in ("", None) def _is_faild_plan(self, rsp: List[Dict]): # is faild plan? - func_code = OpenCodeInterpreter.extract_function(rsp, 'function') + func_code = OpenCodeInterpreter.extract_function(rsp, "function") # If there is no more than 1 '\n', the plan execution fails. - if isinstance(func_code, str) and func_code.count('\n') <= 1: + if isinstance(func_code, str) and func_code.count("\n") <= 1: return True return False @@ -184,4 +194,5 @@ class OpenInterpreterDecorator(object): logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") raise Exception("Could not evaluate Python code", e) return res + return wrapper(wrapped) diff --git a/metagpt/tools/prompt_writer.py b/metagpt/tools/prompt_writer.py index d90599206..ffcff4d1f 100644 --- a/metagpt/tools/prompt_writer.py +++ b/metagpt/tools/prompt_writer.py @@ -10,8 +10,9 @@ from typing import Union class GPTPromptGenerator: """Using LLM, given an output, request LLM to provide input (supporting instruction, chatbot, and query styles)""" + def __init__(self): - self._generators = {i: getattr(self, f"gen_{i}_style") for i in ['instruction', 'chatbot', 'query']} + self._generators = {i: getattr(self, f"gen_{i}_style") for i in ["instruction", "chatbot", "query"]} def gen_instruction_style(self, example): """Instruction style: Given an output, request LLM to provide input""" @@ -35,7 +36,7 @@ Query: X Document: {example} What is the detailed query X? X:""" - def gen(self, example: str, style: str = 'all') -> Union[list[str], str]: + def gen(self, example: str, style: str = "all") -> Union[list[str], str]: """ Generate one or multiple outputs using the example, allowing LLM to reply with the corresponding input @@ -43,7 +44,7 @@ X:""" :param style: (all|instruction|chatbot|query) :return: Expected LLM input sample (one or multiple) """ - if style != 'all': + if style != "all": return self._generators[style](example) return [f(example) for f in self._generators.values()] diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index 1d9cd0b2a..a63dbe5ac 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -120,11 +120,13 @@ def decode_base64_to_image(img, save_name): image.save(f"{save_name}.png", pnginfo=pnginfo) return pnginfo, image + def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) + if __name__ == "__main__": engine = SDEngine() prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 942ef7edd..64388a11f 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -6,7 +6,7 @@ @File : search_engine.py """ import importlib -from typing import Callable, Coroutine, Literal, overload, Optional, Union +from typing import Callable, Coroutine, Literal, Optional, Union, overload from semantic_kernel.skill_definition import sk_function @@ -43,8 +43,8 @@ class SearchEngine: def __init__( self, - engine: Optional[SearchEngineType] = None, - run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None, + engine: Optional[SearchEngineType] = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None, ): engine = engine or CONFIG.search_engine if engine == SearchEngineType.SERPAPI_GOOGLE: diff --git a/metagpt/tools/search_engine_meilisearch.py b/metagpt/tools/search_engine_meilisearch.py index da4269384..f7c1c685a 100644 --- a/metagpt/tools/search_engine_meilisearch.py +++ b/metagpt/tools/search_engine_meilisearch.py @@ -29,7 +29,7 @@ class MeilisearchEngine: def add_documents(self, data_source: DataSource, documents: List[dict]): index_name = f"{data_source.name}_index" if index_name not in self.client.get_indexes(): - self.client.create_index(uid=index_name, options={'primaryKey': 'id'}) + self.client.create_index(uid=index_name, options={"primaryKey": "id"}) index = self.client.get_index(index_name) index.add_documents(documents) self.set_index(index) @@ -37,7 +37,7 @@ class MeilisearchEngine: def search(self, query): try: search_results = self._index.search(query) - return search_results['hits'] + return search_results["hits"] except Exception as e: # Handle MeiliSearch API errors print(f"MeiliSearch API error: {e}") diff --git a/metagpt/tools/translator.py b/metagpt/tools/translator.py index 910638469..63e38d5a5 100644 --- a/metagpt/tools/translator.py +++ b/metagpt/tools/translator.py @@ -6,7 +6,7 @@ @File : translator.py """ -prompt = ''' +prompt = """ # 指令 接下来,作为一位拥有20年翻译经验的翻译专家,当我给出英文句子或段落时,你将提供通顺且具有可读性的{LANG}翻译。注意以下要求: 1. 确保翻译结果流畅且易于理解 @@ -17,11 +17,10 @@ prompt = ''' {ORIGINAL} # 译文 -''' +""" class Translator: - @classmethod - def translate_prompt(cls, original, lang='中文'): - return prompt.format(LANG=lang, ORIGINAL=original) \ No newline at end of file + def translate_prompt(cls, original, lang="中文"): + return prompt.format(LANG=lang, ORIGINAL=original) diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 43ca72150..64423dfb1 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -6,7 +6,7 @@ from pathlib import Path from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI -ICL_SAMPLE = '''Interface definition: +ICL_SAMPLE = """Interface definition: ```text Interface Name: Element Tagging Interface Path: /projects/{project_key}/node-tags @@ -60,20 +60,20 @@ def test_node_tags(project_key, nodes, operations, expected_msg): # 3. If comments are needed, use Chinese. # If you understand, please wait for me to give the interface definition and just answer "Understood" to save tokens. -''' +""" -ACT_PROMPT_PREFIX = '''Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type. +ACT_PROMPT_PREFIX = """Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type. Please output 10 test cases within one `@pytest.mark.parametrize` scope. ```text -''' +""" -YFT_PROMPT_PREFIX = '''Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation, +YFT_PROMPT_PREFIX = """Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation, authentication and authorization, parameter verification, exception handling, file upload and download. Please output 10 test cases within one `@pytest.mark.parametrize` scope. ```text -''' +""" -OCR_API_DOC = '''```text +OCR_API_DOC = """```text Interface Name: OCR recognition Interface Path: /api/v1/contract/treaty/task/ocr Method: POST @@ -96,14 +96,20 @@ code integer Yes message string Yes data object Yes ``` -''' +""" class UTGenerator: """UT Generator: Construct UT through API documentation""" - def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str, - chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None: + def __init__( + self, + swagger_file: str, + ut_py_path: str, + questions_path: str, + chatgpt_method: str = "API", + template_prefix=YFT_PROMPT_PREFIX, + ) -> None: """Initialize UT Generator Args: @@ -274,7 +280,7 @@ class UTGenerator: def gpt_msgs_to_code(self, messages: list) -> str: """Choose based on different calling methods""" - result = '' + result = "" if self.chatgpt_method == "API": result = GPTAPI().ask_code(msgs=messages) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index f3691549b..6bb9a1a97 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -6,9 +6,10 @@ @File : file.py @Describe : General file operations. """ -import aiofiles from pathlib import Path +import aiofiles + from metagpt.logs import logger @@ -66,10 +67,9 @@ class File: if not chunk: break chunks.append(chunk) - content = b''.join(chunks) + content = b"".join(chunks) logger.debug(f"Successfully read file, the path of file: {file_path}") return content except Exception as e: logger.error(f"Error reading file: {e}") raise e - diff --git a/metagpt/utils/highlight.py b/metagpt/utils/highlight.py index e6cbb228c..2e1d6f615 100644 --- a/metagpt/utils/highlight.py +++ b/metagpt/utils/highlight.py @@ -1,22 +1,22 @@ # 添加代码语法高亮显示 from pygments import highlight as highlight_ +from pygments.formatters import HtmlFormatter, TerminalFormatter from pygments.lexers import PythonLexer, SqlLexer -from pygments.formatters import TerminalFormatter, HtmlFormatter -def highlight(code: str, language: str = 'python', formatter: str = 'terminal'): +def highlight(code: str, language: str = "python", formatter: str = "terminal"): # 指定要高亮的语言 - if language.lower() == 'python': + if language.lower() == "python": lexer = PythonLexer() - elif language.lower() == 'sql': + elif language.lower() == "sql": lexer = SqlLexer() else: raise ValueError(f"Unsupported language: {language}") # 指定输出格式 - if formatter.lower() == 'terminal': + if formatter.lower() == "terminal": formatter = TerminalFormatter() - elif formatter.lower() == 'html': + elif formatter.lower() == "html": formatter = HtmlFormatter() else: raise ValueError(f"Unsupported formatter: {formatter}") diff --git a/metagpt/utils/mmdc_ink.py b/metagpt/utils/mmdc_ink.py index 3d91cde9d..d594adb30 100644 --- a/metagpt/utils/mmdc_ink.py +++ b/metagpt/utils/mmdc_ink.py @@ -6,9 +6,9 @@ @File : mermaid.py """ import base64 -import os -from aiohttp import ClientSession,ClientError +from aiohttp import ClientError, ClientSession + from metagpt.logs import logger @@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix): async with session.get(url) as response: if response.status == 200: text = await response.content.read() - with open(output_file, 'wb') as f: + with open(output_file, "wb") as f: f.write(text) logger.info(f"Generating {output_file}..") else: diff --git a/metagpt/utils/mmdc_playwright.py b/metagpt/utils/mmdc_playwright.py index bdbfd82ff..5d455e1c5 100644 --- a/metagpt/utils/mmdc_playwright.py +++ b/metagpt/utils/mmdc_playwright.py @@ -8,10 +8,13 @@ import os from urllib.parse import urljoin + from playwright.async_api import async_playwright + from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int: + +async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: """ Converts the given Mermaid code to various output formats and saves them to files. @@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, Returns: int: Returns 1 if the conversion and saving were successful, -1 otherwise. """ - suffixes=['png', 'svg', 'pdf'] + suffixes = ["png", "svg", "pdf"] __dirname = os.path.dirname(os.path.abspath(__file__)) async with async_playwright() as p: browser = await p.chromium.launch() device_scale_factor = 1.0 context = await browser.new_context( - viewport={'width': width, 'height': height}, - device_scale_factor=device_scale_factor, - ) + viewport={"width": width, "height": height}, + device_scale_factor=device_scale_factor, + ) page = await context.new_page() async def console_message(msg): logger.info(msg.text) - page.on('console', console_message) + + page.on("console", console_message) try: - await page.set_viewport_size({'width': width, 'height': height}) + await page.set_viewport_size({"width": width, "height": height}) - mermaid_html_path = os.path.abspath( - os.path.join(__dirname, 'index.html')) - mermaid_html_url = urljoin('file:', mermaid_html_path) + mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html")) + mermaid_html_url = urljoin("file:", mermaid_html_path) await page.goto(mermaid_html_url) await page.wait_for_load_state("networkidle") await page.wait_for_selector("div#container", state="attached") - mermaid_config = {} + # mermaid_config = {} background_color = "#ffffff" - my_css = "" + # my_css = "" await page.evaluate(f'document.body.style.background = "{background_color}";') - metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => { - const { mermaid, zenuml } = globalThis; - await mermaid.registerExternalDiagrams([zenuml]); - mermaid.initialize({ startOnLoad: false, ...mermaidConfig }); - const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container')); - document.getElementById('container').innerHTML = svg; - const svgElement = document.querySelector('svg'); - svgElement.style.backgroundColor = backgroundColor; + # metadata = await page.evaluate( + # """async ([definition, mermaidConfig, myCSS, backgroundColor]) => { + # const { mermaid, zenuml } = globalThis; + # await mermaid.registerExternalDiagrams([zenuml]); + # mermaid.initialize({ startOnLoad: false, ...mermaidConfig }); + # const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container')); + # document.getElementById('container').innerHTML = svg; + # const svgElement = document.querySelector('svg'); + # svgElement.style.backgroundColor = backgroundColor; + # + # if (myCSS) { + # const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); + # style.appendChild(document.createTextNode(myCSS)); + # svgElement.appendChild(style); + # } + # + # }""", + # [mermaid_code, mermaid_config, my_css, background_color], + # ) - if (myCSS) { - const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); - style.appendChild(document.createTextNode(myCSS)); - svgElement.appendChild(style); - } - - }''', [mermaid_code, mermaid_config, my_css, background_color]) - - if 'svg' in suffixes : - svg_xml = await page.evaluate('''() => { + if "svg" in suffixes: + svg_xml = await page.evaluate( + """() => { const svg = document.querySelector('svg'); const xmlSerializer = new XMLSerializer(); return xmlSerializer.serializeToString(svg); - }''') + }""" + ) logger.info(f"Generating {output_file_without_suffix}.svg..") - with open(f'{output_file_without_suffix}.svg', 'wb') as f: - f.write(svg_xml.encode('utf-8')) + with open(f"{output_file_without_suffix}.svg", "wb") as f: + f.write(svg_xml.encode("utf-8")) - if 'png' in suffixes: - clip = await page.evaluate('''() => { + if "png" in suffixes: + clip = await page.evaluate( + """() => { const svg = document.querySelector('svg'); const rect = svg.getBoundingClientRect(); return { @@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, width: Math.ceil(rect.width), height: Math.ceil(rect.height) }; - }''') - await page.set_viewport_size({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height']}) - screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device') + }""" + ) + await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]}) + screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device") logger.info(f"Generating {output_file_without_suffix}.png..") - with open(f'{output_file_without_suffix}.png', 'wb') as f: + with open(f"{output_file_without_suffix}.png", "wb") as f: f.write(screenshot) - if 'pdf' in suffixes: + if "pdf" in suffixes: pdf_data = await page.pdf(scale=device_scale_factor) logger.info(f"Generating {output_file_without_suffix}.pdf..") - with open(f'{output_file_without_suffix}.pdf', 'wb') as f: + with open(f"{output_file_without_suffix}.pdf", "wb") as f: f.write(pdf_data) return 0 except Exception as e: diff --git a/metagpt/utils/mmdc_pyppeteer.py b/metagpt/utils/mmdc_pyppeteer.py index 7ec30fd12..7125cafc5 100644 --- a/metagpt/utils/mmdc_pyppeteer.py +++ b/metagpt/utils/mmdc_pyppeteer.py @@ -7,11 +7,14 @@ """ import os from urllib.parse import urljoin -from pyppeteer import launch -from metagpt.logs import logger -from metagpt.config import CONFIG -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int: +from pyppeteer import launch + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: """ Converts the given Mermaid code to various output formats and saves them to files. @@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, Returns: int: Returns 1 if the conversion and saving were successful, -1 otherwise. """ - suffixes = ['png', 'svg', 'pdf'] + suffixes = ["png", "svg", "pdf"] __dirname = os.path.dirname(os.path.abspath(__file__)) - if CONFIG.pyppeteer_executable_path: - browser = await launch(headless=True, - executablePath=CONFIG.pyppeteer_executable_path, - args=['--disable-extensions',"--no-sandbox"] - ) + browser = await launch( + headless=True, + executablePath=CONFIG.pyppeteer_executable_path, + args=["--disable-extensions", "--no-sandbox"], + ) else: logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.") return -1 @@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, async def console_message(msg): logger.info(msg.text) - page.on('console', console_message) + + page.on("console", console_message) try: - await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor}) + await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor}) - mermaid_html_path = os.path.abspath( - os.path.join(__dirname, 'index.html')) - mermaid_html_url = urljoin('file:', mermaid_html_path) + mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html")) + mermaid_html_url = urljoin("file:", mermaid_html_path) await page.goto(mermaid_html_url) await page.querySelector("div#container") - mermaid_config = {} + # mermaid_config = {} background_color = "#ffffff" - my_css = "" + # my_css = "" await page.evaluate(f'document.body.style.background = "{background_color}";') - metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => { - const { mermaid, zenuml } = globalThis; - await mermaid.registerExternalDiagrams([zenuml]); - mermaid.initialize({ startOnLoad: false, ...mermaidConfig }); - const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container')); - document.getElementById('container').innerHTML = svg; - const svgElement = document.querySelector('svg'); - svgElement.style.backgroundColor = backgroundColor; + # metadata = await page.evaluate( + # """async ([definition, mermaidConfig, myCSS, backgroundColor]) => { + # const { mermaid, zenuml } = globalThis; + # await mermaid.registerExternalDiagrams([zenuml]); + # mermaid.initialize({ startOnLoad: false, ...mermaidConfig }); + # const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container')); + # document.getElementById('container').innerHTML = svg; + # const svgElement = document.querySelector('svg'); + # svgElement.style.backgroundColor = backgroundColor; + # + # if (myCSS) { + # const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); + # style.appendChild(document.createTextNode(myCSS)); + # svgElement.appendChild(style); + # } + # }""", + # [mermaid_code, mermaid_config, my_css, background_color], + # ) - if (myCSS) { - const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); - style.appendChild(document.createTextNode(myCSS)); - svgElement.appendChild(style); - } - }''', [mermaid_code, mermaid_config, my_css, background_color]) - - if 'svg' in suffixes : - svg_xml = await page.evaluate('''() => { + if "svg" in suffixes: + svg_xml = await page.evaluate( + """() => { const svg = document.querySelector('svg'); const xmlSerializer = new XMLSerializer(); return xmlSerializer.serializeToString(svg); - }''') + }""" + ) logger.info(f"Generating {output_file_without_suffix}.svg..") - with open(f'{output_file_without_suffix}.svg', 'wb') as f: - f.write(svg_xml.encode('utf-8')) + with open(f"{output_file_without_suffix}.svg", "wb") as f: + f.write(svg_xml.encode("utf-8")) - if 'png' in suffixes: - clip = await page.evaluate('''() => { + if "png" in suffixes: + clip = await page.evaluate( + """() => { const svg = document.querySelector('svg'); const rect = svg.getBoundingClientRect(); return { @@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, width: Math.ceil(rect.width), height: Math.ceil(rect.height) }; - }''') - await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor}) - screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device') + }""" + ) + await page.setViewport( + { + "width": clip["x"] + clip["width"], + "height": clip["y"] + clip["height"], + "deviceScaleFactor": device_scale_factor, + } + ) + screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device") logger.info(f"Generating {output_file_without_suffix}.png..") - with open(f'{output_file_without_suffix}.png', 'wb') as f: + with open(f"{output_file_without_suffix}.png", "wb") as f: f.write(screenshot) - if 'pdf' in suffixes: + if "pdf" in suffixes: pdf_data = await page.pdf(scale=device_scale_factor) logger.info(f"Generating {output_file_without_suffix}.pdf..") - with open(f'{output_file_without_suffix}.pdf', 'wb') as f: + with open(f"{output_file_without_suffix}.pdf", "wb") as f: f.write(pdf_data) return 0 except Exception as e: @@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, return -1 finally: await browser.close() - diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py index 62de26541..f2395026f 100644 --- a/metagpt/utils/parse_html.py +++ b/metagpt/utils/parse_html.py @@ -16,7 +16,7 @@ class WebPage(BaseModel): class Config: underscore_attrs_are_private = True - _soup : Optional[BeautifulSoup] = None + _soup: Optional[BeautifulSoup] = None _title: Optional[str] = None @property @@ -24,7 +24,7 @@ class WebPage(BaseModel): if self._soup is None: self._soup = BeautifulSoup(self.html, "html.parser") return self._soup - + @property def title(self): if self._title is None: diff --git a/metagpt/utils/pycst.py b/metagpt/utils/pycst.py index afd85a547..1edfed81c 100644 --- a/metagpt/utils/pycst.py +++ b/metagpt/utils/pycst.py @@ -37,12 +37,12 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine: if not isinstance(expr, cst.Expr): return None - + val = expr.value if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)): return None - - evaluated_value = val.evaluated_value + + evaluated_value = val.evaluated_value if isinstance(evaluated_value, bytes): return None @@ -56,6 +56,7 @@ class DocstringCollector(cst.CSTVisitor): stack: A list to keep track of the current path in the CST. docstrings: A dictionary mapping paths in the CST to their corresponding docstrings. """ + def __init__(self): self.stack: list[str] = [] self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {} @@ -96,6 +97,7 @@ class DocstringTransformer(cst.CSTTransformer): stack: A list to keep track of the current path in the CST. docstrings: A dictionary mapping paths in the CST to their corresponding docstrings. """ + def __init__( self, docstrings: dict[tuple[str, ...], cst.SimpleStatementLine], @@ -125,7 +127,9 @@ class DocstringTransformer(cst.CSTTransformer): key = tuple(self.stack) self.stack.pop() - if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators): + if hasattr(updated_node, "decorators") and any( + (i.decorator.value == "overload") for i in updated_node.decorators + ): return updated_node statement = self.docstrings.get(key) diff --git a/metagpt/utils/read_document.py b/metagpt/utils/read_document.py index c837baf25..d2fafbc17 100644 --- a/metagpt/utils/read_document.py +++ b/metagpt/utils/read_document.py @@ -8,6 +8,7 @@ import docx + def read_docx(file_path: str) -> list: """Open a docx file""" doc = docx.Document(file_path) diff --git a/metagpt/utils/singleton.py b/metagpt/utils/singleton.py index 474b537db..a9e0862c0 100644 --- a/metagpt/utils/singleton.py +++ b/metagpt/utils/singleton.py @@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] - \ No newline at end of file diff --git a/metagpt/utils/special_tokens.py b/metagpt/utils/special_tokens.py index 2adb93c77..5e780ce05 100644 --- a/metagpt/utils/special_tokens.py +++ b/metagpt/utils/special_tokens.py @@ -1,4 +1,4 @@ # token to separate different code messages in a WriteCode Message content -MSG_SEP = "#*000*#" +MSG_SEP = "#*000*#" # token to seperate file name and the actual code text in a code message FILENAME_CODE_SEP = "#*001*#" diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py index be3c52edd..dd9678438 100644 --- a/metagpt/utils/text.py +++ b/metagpt/utils/text.py @@ -3,7 +3,12 @@ from typing import Generator, Sequence from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens -def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str: +def reduce_message_length( + msgs: Generator[str, None, None], + model_name: str, + system_text: str, + reserved: int = 0, +) -> str: """Reduce the length of concatenated message segments to fit within the maximum token size. Args: @@ -49,9 +54,9 @@ def generate_prompt_chunk( current_token = 0 current_lines = [] - reserved = reserved + count_string_tokens(prompt_template+system_text, model_name) + reserved = reserved + count_string_tokens(prompt_template + system_text, model_name) # 100 is a magic number to ensure the maximum context length is not exceeded - max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100 + max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100 while paragraphs: paragraph = paragraphs.pop(0) @@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str: return text.encode("utf-8").decode("unicode_escape", "ignore") -def _split_by_count(lst: Sequence , count: int): +def _split_by_count(lst: Sequence, count: int): avg = len(lst) // count remainder = len(lst) % count start = 0 diff --git a/requirements.txt b/requirements.txt index 093298775..24a2d94c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,3 +44,4 @@ ta==0.10.2 semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 + diff --git a/tests/conftest.py b/tests/conftest.py index feecc7715..d2ac8304f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,14 +6,14 @@ @File : conftest.py """ +import asyncio +import re from unittest.mock import Mock import pytest from metagpt.logs import logger from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI -import asyncio -import re class Context: diff --git a/tests/metagpt/actions/mock.py b/tests/metagpt/actions/mock.py index a800690e8..23d10ccc4 100644 --- a/tests/metagpt/actions/mock.py +++ b/tests/metagpt/actions/mock.py @@ -311,12 +311,10 @@ TASKS = [ "添加数据API:接受用户输入的文档库,对文档库进行索引\n- 使用MeiliSearch连接并添加文档库", "搜索API:接收用户输入的关键词,返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据", "多条件筛选API:接收用户选择的筛选条件,返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果", - "智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。" + "智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。", ] -TASKS_2 = [ - "完成main.py的功能" -] +TASKS_2 = ["完成main.py的功能"] SEARCH_CODE_SAMPLE = """ import requests @@ -460,7 +458,7 @@ if __name__ == '__main__': print('No results found.') ''' -MEILI_CODE = '''import meilisearch +MEILI_CODE = """import meilisearch from typing import List @@ -496,9 +494,9 @@ if __name__ == '__main__': # 添加文档库到搜索引擎 search_engine.add_documents(books_data_source, documents) -''' +""" -MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py +MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py Traceback (most recent call last): File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in search_engine.add_documents(books_data_source, documents) @@ -506,7 +504,7 @@ Traceback (most recent call last): index = self.client.get_or_create_index(index_name) AttributeError: 'Client' object has no attribute 'get_or_create_index' -Process finished with exit code 1''' +Process finished with exit code 1""" MEILI_CODE_REFINED = """ """ diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py index a556789db..ef8e239bd 100644 --- a/tests/metagpt/actions/test_action_output.py +++ b/tests/metagpt/actions/test_action_output.py @@ -9,18 +9,21 @@ from typing import List, Tuple from metagpt.actions import ActionOutput -t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n", - "Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n", - "Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n", - "Logic Analysis": [ - ["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."], - ["game.py", "Contains the Game and Snake classes. Handles the game logic."], - ["static/js/script.js", "Handles user interactions and updates the game UI."], - ["static/css/styles.css", "Defines the styles for the game UI."], - ["templates/index.html", "The main page of the web application. Displays the game UI."]], - "Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"], - "Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n", - "Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"} +t_dict = { + "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', + "Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n', + "Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n', + "Logic Analysis": [ + ["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."], + ["game.py", "Contains the Game and Snake classes. Handles the game logic."], + ["static/js/script.js", "Handles user interactions and updates the game UI."], + ["static/css/styles.css", "Defines the styles for the game UI."], + ["templates/index.html", "The main page of the web application. Displays the game UI."], + ], + "Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"], + "Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n", + "Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?", +} WRITE_TASKS_OUTPUT_MAPPING = { "Required Python third-party packages": (str, ...), @@ -45,6 +48,6 @@ def test_create_model_class_with_mapping(): assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] -if __name__ == '__main__': +if __name__ == "__main__": test_create_model_class() test_create_model_class_with_mapping() diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py index b5a333af2..bcafe10f5 100644 --- a/tests/metagpt/actions/test_azure_tts.py +++ b/tests/metagpt/actions/test_azure_tts.py @@ -10,12 +10,7 @@ from metagpt.actions.azure_tts import AzureTTS def test_azure_tts(): azure_tts = AzureTTS("azure_tts") - azure_tts.synthesize_speech( - "zh-CN", - "zh-CN-YunxiNeural", - "Boy", - "你好,我是卡卡", - "output.wav") + azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") # 运行需要先配置 SUBSCRIPTION_KEY # TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有 diff --git a/tests/metagpt/actions/test_clone_function.py b/tests/metagpt/actions/test_clone_function.py index 6d4432dcd..44248eb80 100644 --- a/tests/metagpt/actions/test_clone_function.py +++ b/tests/metagpt/actions/test_clone_function.py @@ -2,7 +2,6 @@ import pytest from metagpt.actions.clone_function import CloneFunction, run_function_code - source_code = """ import pandas as pd import ta @@ -31,14 +30,18 @@ def get_expected_res(): import ta # 读取股票数据 - stock_data = pd.read_csv('./tests/data/baba_stock.csv') + stock_data = pd.read_csv("./tests/data/baba_stock.csv") stock_data.head() # 计算简单移动平均线 - stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6) - stock_data[['Date', 'Close', 'SMA']].head() + stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6) + stock_data[["Date", "Close", "SMA"]].head() # 计算布林带 - stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20) - stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head() + stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = ( + ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20), + ta.volatility.bollinger_mavg(stock_data["Close"], window=20), + ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20), + ) + stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head() return stock_data @@ -46,9 +49,9 @@ def get_expected_res(): async def test_clone_function(): clone = CloneFunction() code = await clone.run(template_code, source_code) - assert 'def ' in code - stock_path = './tests/data/baba_stock.csv' - df, msg = run_function_code(code, 'stock_indicator', stock_path) + assert "def " in code + stock_path = "./tests/data/baba_stock.csv" + df, msg = run_function_code(code, "stock_indicator", stock_path) assert not msg expected_df = get_expected_res() assert df.equals(expected_df) diff --git a/tests/metagpt/actions/test_debug_error.py b/tests/metagpt/actions/test_debug_error.py index 555c84e4e..2393d2cc9 100644 --- a/tests/metagpt/actions/test_debug_error.py +++ b/tests/metagpt/actions/test_debug_error.py @@ -144,12 +144,12 @@ Engineer --- ''' + @pytest.mark.asyncio async def test_debug_error(): - debug_error = DebugError("debug_error") file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT) - assert "class Player" in rewritten_code # rewrite the same class - assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12") + assert "class Player" in rewritten_code # rewrite the same class + assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12") diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index c9d5331f9..891dca6ca 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -10,6 +10,7 @@ import pytest from metagpt.actions.detail_mining import DetailMining from metagpt.logs import logger + @pytest.mark.asyncio async def test_detail_mining(): topic = "如何做一个生日蛋糕" @@ -17,7 +18,6 @@ async def test_detail_mining(): detail_mining = DetailMining("detail_mining") rsp = await detail_mining.run(topic=topic, record=record) logger.info(f"{rsp.content=}") - - assert '##OUTPUT' in rsp.content - assert '蛋糕' in rsp.content + assert "##OUTPUT" in rsp.content + assert "蛋糕" in rsp.content diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index a15166f7c..7f16aa9a4 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -8,12 +8,11 @@ """ import os -from typing import List - -import pytest from pathlib import Path -from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion +import pytest + +from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion @pytest.mark.asyncio @@ -22,7 +21,7 @@ from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion [ "../../data/invoices/invoice-3.jpg", "../../data/invoices/invoice-4.zip", - ] + ], ) async def test_invoice_ocr(invoice_path: str): invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) @@ -35,18 +34,8 @@ async def test_invoice_ocr(invoice_path: str): @pytest.mark.parametrize( ("invoice_path", "expected_result"), [ - ( - "../../data/invoices/invoice-1.pdf", - [ - { - "收款人": "小明", - "城市": "深圳市", - "总费用/元": "412.00", - "开票日期": "2023年02月03日" - } - ] - ), - ] + ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), + ], ) async def test_generate_table(invoice_path: str, expected_result: list[dict]): invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) @@ -59,9 +48,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]): @pytest.mark.asyncio @pytest.mark.parametrize( ("invoice_path", "query", "expected_result"), - [ - ("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日") - ] + [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], ) async def test_reply_question(invoice_path: str, query: dict, expected_result: str): invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) @@ -69,4 +56,3 @@ async def test_reply_question(invoice_path: str, query: dict, expected_result: s ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) assert expected_result in result - diff --git a/tests/metagpt/actions/test_ui_design.py b/tests/metagpt/actions/test_ui_design.py index d284b20f2..b8be914ae 100644 --- a/tests/metagpt/actions/test_ui_design.py +++ b/tests/metagpt/actions/test_ui_design.py @@ -4,7 +4,7 @@ # from tests.metagpt.roles.ui_role import UIDesign -llm_resp= ''' +llm_resp = """ # UI Design Description ```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.``` @@ -98,12 +98,13 @@ body { left: 50%; transform: translate(-50%, -50%); font-size: 3em; - ''' + """ + def test_ui_design_parse_css(): ui_design_work = UIDesign(name="UI design action") - css = ''' + css = """ body { display: flex; flex-direction: column; @@ -160,14 +161,14 @@ def test_ui_design_parse_css(): left: 50%; transform: translate(-50%, -50%); font-size: 3em; - ''' - assert ui_design_work.parse_css_code(context=llm_resp)==css + """ + assert ui_design_work.parse_css_code(context=llm_resp) == css def test_ui_design_parse_html(): ui_design_work = UIDesign(name="UI design action") - html = ''' + html = """ @@ -184,8 +185,5 @@ def test_ui_design_parse_html():
Game Over
- ''' - assert ui_design_work.parse_css_code(context=llm_resp)==html - - - + """ + assert ui_design_work.parse_css_code(context=llm_resp) == html diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 7bb18ddf2..eb5e3de91 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -22,13 +22,13 @@ async def test_write_code(): logger.info(code) # 我们不能精确地预测生成的代码,但我们可以检查某些关键字 - assert 'def add' in code - assert 'return' in code + assert "def add" in code + assert "return" in code @pytest.mark.asyncio async def test_write_code_directly(): - prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0] + prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0] llm = LLM() rsp = await llm.aask(prompt) logger.info(rsp) diff --git a/tests/metagpt/actions/test_write_docstring.py b/tests/metagpt/actions/test_write_docstring.py index 82d96e1a6..a8a80b36d 100644 --- a/tests/metagpt/actions/test_write_docstring.py +++ b/tests/metagpt/actions/test_write_docstring.py @@ -2,7 +2,7 @@ import pytest from metagpt.actions.write_docstring import WriteDocstring -code = ''' +code = """ def add_numbers(a: int, b: int): return a + b @@ -14,7 +14,7 @@ class Person: def greet(self): return f"Hello, my name is {self.name} and I am {self.age} years old." -''' +""" @pytest.mark.asyncio @@ -25,7 +25,7 @@ class Person: ("numpy", "Parameters"), ("sphinx", ":param name:"), ], - ids=["google", "numpy", "sphinx"] + ids=["google", "numpy", "sphinx"], ) async def test_write_docstring(style: str, part: str): ret = await WriteDocstring().run(code, style=style) diff --git a/tests/metagpt/actions/test_write_tutorial.py b/tests/metagpt/actions/test_write_tutorial.py index 683fee082..27a323b44 100644 --- a/tests/metagpt/actions/test_write_tutorial.py +++ b/tests/metagpt/actions/test_write_tutorial.py @@ -9,14 +9,11 @@ from typing import Dict import pytest -from metagpt.actions.write_tutorial import WriteDirectory, WriteContent +from metagpt.actions.write_tutorial import WriteContent, WriteDirectory @pytest.mark.asyncio -@pytest.mark.parametrize( - ("language", "topic"), - [("English", "Write a tutorial about Python")] -) +@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")]) async def test_write_directory(language: str, topic: str): ret = await WriteDirectory(language=language).run(topic=topic) assert isinstance(ret, dict) @@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str): @pytest.mark.asyncio @pytest.mark.parametrize( ("language", "topic", "directory"), - [("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})] + [("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})], ) async def test_write_content(language: str, topic: str, directory: Dict): ret = await WriteContent(language=language, directory=directory).run(topic=topic) diff --git a/tests/metagpt/document_store/test_chromadb_store.py b/tests/metagpt/document_store/test_chromadb_store.py index f8c11e1ca..fd115dcdd 100644 --- a/tests/metagpt/document_store/test_chromadb_store.py +++ b/tests/metagpt/document_store/test_chromadb_store.py @@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore def test_chroma_store(): """FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是""" # 创建 ChromaStore 实例,使用 'sample_collection' 集合 - document_store = ChromaStore('sample_collection_1') + document_store = ChromaStore("sample_collection_1") # 使用 write 方法添加多个文档 - document_store.write(["This is document1", "This is document2"], - [{"source": "google-docs"}, {"source": "notion"}], - ["doc1", "doc2"]) + document_store.write( + ["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"] + ) # 使用 add 方法添加一个文档 document_store.add("This is document3", {"source": "notion"}, "doc3") diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py index d22d234f5..f14bee817 100644 --- a/tests/metagpt/document_store/test_faiss_store.py +++ b/tests/metagpt/document_store/test_faiss_store.py @@ -39,11 +39,11 @@ user: 没有了 @pytest.mark.asyncio async def test_faiss_store_search(): - store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json') - store.add(['油皮洗面奶']) + store = FaissStore(DATA_PATH / "qcs/qcs_4w.json") + store.add(["油皮洗面奶"]) role = Sales(store=store) - queries = ['油皮洗面奶', '介绍下欧莱雅的'] + queries = ["油皮洗面奶", "介绍下欧莱雅的"] for query in queries: rsp = await role.run(query) assert rsp @@ -60,7 +60,10 @@ def customer_service(): async def test_faiss_store_customer_service(): allq = [ # ["我的餐怎么两小时都没到", "退货吧"], - ["你好收不到取餐码,麻烦帮我开箱", "14750187158", ] + [ + "你好收不到取餐码,麻烦帮我开箱", + "14750187158", + ] ] role = customer_service() for queries in allq: @@ -71,4 +74,4 @@ async def test_faiss_store_customer_service(): def test_faiss_store_no_file(): with pytest.raises(FileNotFoundError): - FaissStore(DATA_PATH / 'wtf.json') + FaissStore(DATA_PATH / "wtf.json") diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py index 9c2f9fb42..5c0e40f57 100644 --- a/tests/metagpt/document_store/test_lancedb_store.py +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -5,27 +5,33 @@ @Author : unkn-wn (Leon Yee) @File : test_lancedb_store.py """ -from metagpt.document_store.lancedb_store import LanceStore -import pytest import random +import pytest + +from metagpt.document_store.lancedb_store import LanceStore + + @pytest def test_lance_store(): - # This simply establishes the connection to the database, so we can drop the table if it exists - store = LanceStore('test') + store = LanceStore("test") - store.drop('test') + store.drop("test") - store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], - metadatas=[{"source": "google-docs"}, {"source": "notion"}], - ids=["doc1", "doc2"]) + store.write( + data=[[random.random() for _ in range(100)] for _ in range(2)], + metadatas=[{"source": "google-docs"}, {"source": "notion"}], + ids=["doc1", "doc2"], + ) store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3") result = store.search([random.random() for _ in range(100)], n_results=3) - assert(len(result) == 3) + assert len(result) == 3 store.delete("doc2") - result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine') - assert(len(result) == 1) \ No newline at end of file + result = store.search( + [random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine" + ) + assert len(result) == 1 diff --git a/tests/metagpt/document_store/test_milvus_store.py b/tests/metagpt/document_store/test_milvus_store.py index 1cf65776d..34497b9c6 100644 --- a/tests/metagpt/document_store/test_milvus_store.py +++ b/tests/metagpt/document_store/test_milvus_store.py @@ -12,7 +12,7 @@ import numpy as np from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore from metagpt.logs import logger -book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float} +book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float} book_data = [ [i for i in range(10)], [f"book-{i}" for i in range(10)], @@ -25,12 +25,12 @@ book_data = [ def test_milvus_store(): milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530") milvus_store = MilvusStore(milvus_connection) - milvus_store.drop('Book') - milvus_store.create_collection('Book', book_columns) + milvus_store.drop("Book") + milvus_store.create_collection("Book", book_columns) milvus_store.add(book_data) - milvus_store.build_index('emb') + milvus_store.build_index("emb") milvus_store.load_collection() - results = milvus_store.search([[1.0, 1.0]], field='emb') + results = milvus_store.search([[1.0, 1.0]], field="emb") logger.info(results) assert results diff --git a/tests/metagpt/document_store/test_qdrant_store.py b/tests/metagpt/document_store/test_qdrant_store.py index a63a4329d..cdd619d37 100644 --- a/tests/metagpt/document_store/test_qdrant_store.py +++ b/tests/metagpt/document_store/test_qdrant_store.py @@ -24,9 +24,7 @@ random.seed(seed_value) vectors = [[random.random() for _ in range(2)] for _ in range(10)] points = [ - PointStruct( - id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10} - ) + PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}) for idx, vector in enumerate(vectors) ] @@ -57,9 +55,7 @@ def test_milvus_store(): results = qdrant_store.search( "Book", query=[1.0, 1.0], - query_filter=Filter( - must=[FieldCondition(key="rand_number", range=Range(gte=8))] - ), + query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]), ) assert results[0]["id"] == 8 assert results[0]["score"] == 0.9100373450784073 @@ -68,9 +64,7 @@ def test_milvus_store(): results = qdrant_store.search( "Book", query=[1.0, 1.0], - query_filter=Filter( - must=[FieldCondition(key="rand_number", range=Range(gte=8))] - ), + query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]), return_vector=True, ) assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915] diff --git a/tests/metagpt/management/test_skill_manager.py b/tests/metagpt/management/test_skill_manager.py index b0be858a1..462bc23a6 100644 --- a/tests/metagpt/management/test_skill_manager.py +++ b/tests/metagpt/management/test_skill_manager.py @@ -30,7 +30,7 @@ def test_skill_manager(): rsp = manager.retrieve_skill("写测试用例") logger.info(rsp) - assert rsp[0] == 'WriteTest' + assert rsp[0] == "WriteTest" rsp = manager.retrieve_skill_scored("写PRD") logger.info(rsp) diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index dc5540520..9682ba760 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -2,11 +2,11 @@ # -*- coding: utf-8 -*- # @Desc : unittest of `metagpt/memory/longterm_memory.py` -from metagpt.config import CONFIG -from metagpt.schema import Message from metagpt.actions import BossRequirement -from metagpt.roles.role import RoleContext +from metagpt.config import CONFIG from metagpt.memory import LongTermMemory +from metagpt.roles.role import RoleContext +from metagpt.schema import Message def test_ltm_search(): @@ -14,25 +14,25 @@ def test_ltm_search(): openai_api_key = CONFIG.openai_api_key assert len(openai_api_key) > 20 - role_id = 'UTUserLtm(Product Manager)' + role_id = "UTUserLtm(Product Manager)" rc = RoleContext(watch=[BossRequirement]) ltm = LongTermMemory() ltm.recover_memory(role_id, rc) - idea = 'Write a cli snake game' - message = Message(role='BOSS', content=idea, cause_by=BossRequirement) + idea = "Write a cli snake game" + message = Message(role="BOSS", content=idea, cause_by=BossRequirement) news = ltm.find_news([message]) assert len(news) == 1 ltm.add(message) - sim_idea = 'Write a game of cli snake' - sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement) + sim_idea = "Write a game of cli snake" + sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement) news = ltm.find_news([sim_message]) assert len(news) == 0 ltm.add(sim_message) - new_idea = 'Write a 2048 web game' - new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement) + new_idea = "Write a 2048 web game" + new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement) news = ltm.find_news([new_message]) assert len(news) == 1 ltm.add(new_message) @@ -47,8 +47,8 @@ def test_ltm_search(): news = ltm_new.find_news([sim_message]) assert len(news) == 0 - new_idea = 'Write a Battle City' - new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement) + new_idea = "Write a Battle City" + new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement) news = ltm_new.find_news([new_message]) assert len(news) == 1 diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 6bb3e8f1d..8b338a79e 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -4,17 +4,16 @@ from typing import List +from metagpt.actions import BossRequirement, WritePRD +from metagpt.actions.action_output import ActionOutput from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message -from metagpt.actions import BossRequirement -from metagpt.actions import WritePRD -from metagpt.actions.action_output import ActionOutput def test_idea_message(): - idea = 'Write a cli snake game' - role_id = 'UTUser1(Product Manager)' - message = Message(role='BOSS', content=idea, cause_by=BossRequirement) + idea = "Write a cli snake game" + role_id = "UTUser1(Product Manager)" + message = Message(role="BOSS", content=idea, cause_by=BossRequirement) memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) @@ -23,13 +22,13 @@ def test_idea_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_idea = 'Write a game of cli snake' - sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement) + sim_idea = "Write a game of cli snake" + sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement) new_messages = memory_storage.search(sim_message) - assert len(new_messages) == 0 # similar, return [] + assert len(new_messages) == 0 # similar, return [] - new_idea = 'Write a 2048 web game' - new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement) + new_idea = "Write a 2048 web game" + new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement) new_messages = memory_storage.search(new_message) assert new_messages[0].content == message.content @@ -38,22 +37,15 @@ def test_idea_message(): def test_actionout_message(): - out_mapping = { - 'field1': (str, ...), - 'field2': (List[str], ...) - } - out_data = { - 'field1': 'field1 value', - 'field2': ['field2 value1', 'field2 value2'] - } - ic_obj = ActionOutput.create_model_class('prd', out_mapping) + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} + out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionOutput.create_model_class("prd", out_mapping) - role_id = 'UTUser2(Architect)' - content = 'The boss has requested the creation of a command-line interface (CLI) snake game' - message = Message(content=content, - instruct_content=ic_obj(**out_data), - role='user', - cause_by=WritePRD) # WritePRD as test action + role_id = "UTUser2(Architect)" + content = "The boss has requested the creation of a command-line interface (CLI) snake game" + message = Message( + content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD + ) # WritePRD as test action memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) @@ -62,19 +54,13 @@ def test_actionout_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_conent = 'The request is command-line interface (CLI) snake game' - sim_message = Message(content=sim_conent, - instruct_content=ic_obj(**out_data), - role='user', - cause_by=WritePRD) + sim_conent = "The request is command-line interface (CLI) snake game" + sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search(sim_message) - assert len(new_messages) == 0 # similar, return [] + assert len(new_messages) == 0 # similar, return [] - new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty' - new_message = Message(content=new_conent, - instruct_content=ic_obj(**out_data), - role='user', - cause_by=WritePRD) + new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" + new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search(new_message) assert new_messages[0].content == message.content diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index 882338a01..6cfe3b02d 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -10,6 +10,6 @@ from metagpt.schema import Message def test_message(): - message = Message(role='user', content='wtf') - assert 'role' in message.to_dict() - assert 'user' in str(message) + message = Message(role="user", content="wtf") + assert "role" in message.to_dict() + assert "user" in str(message) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index bfa2bf76f..3b3dd67f4 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -6,6 +6,6 @@ def test_message(): llm = SparkAPI() logger.info(llm.ask('只回答"收到了"这三个字。')) - result = llm.ask('写一篇五百字的日记') + result = llm.ask("写一篇五百字的日记") logger.info(result) assert len(result) > 100 diff --git a/tests/metagpt/roles/mock.py b/tests/metagpt/roles/mock.py index 52fc4a3c1..1b02fbaa5 100644 --- a/tests/metagpt/roles/mock.py +++ b/tests/metagpt/roles/mock.py @@ -71,7 +71,7 @@ PRD = '''## 原始需求 ``` ''' -SYSTEM_DESIGN = '''## Python package name +SYSTEM_DESIGN = """## Python package name ```python "smart_search_engine" ``` @@ -149,10 +149,10 @@ sequenceDiagram S-->>SE: return summary SE-->>M: return summary ``` -''' +""" -TASKS = '''## Logic Analysis +TASKS = """## Logic Analysis 在这个项目中,所有的模块都依赖于“SearchEngine”类,这是主入口,其他的模块(Index、Ranking和Summary)都通过它交互。另外,"Index"类又依赖于"KnowledgeBase"类,因为它需要从知识库中获取数据。 @@ -181,7 +181,7 @@ task_list = [ ] ``` 这个任务列表首先定义了最基础的模块,然后是依赖这些模块的模块,最后是辅助模块。可以根据团队的能力和资源,同时开发多个任务,只要满足依赖关系。例如,在开发"search.py"之前,可以同时开发"knowledge_base.py"、"index.py"、"ranking.py"和"summary.py"。 -''' +""" TASKS_TOMATO_CLOCK = '''## Required Python third-party packages: Provided in requirements.txt format @@ -224,30 +224,30 @@ task_list = [ TASK = """smart_search_engine/knowledge_base.py""" STRS_FOR_PARSING = [ -""" + """ ## 1 ```python a ``` """, -""" + """ ##2 ```python "a" ``` """, -""" + """ ## 3 ```python a = "a" ``` """, -""" + """ ## 4 ```python a = 'a' ``` -""" +""", ] diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index c0c48d0b1..f44188c17 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -35,13 +35,13 @@ def test_parse_str(): for idx, i in enumerate(STRS_FOR_PARSING): text = CodeParser.parse_str(f"{idx+1}", i) # logger.info(text) - assert text == 'a' + assert text == "a" def test_parse_blocks(): tasks = CodeParser.parse_blocks(TASKS) logger.info(tasks.keys()) - assert 'Task list' in tasks.keys() + assert "Task list" in tasks.keys() target_list = [ diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index 75097e73c..c9aad93a7 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -9,8 +9,8 @@ from pathlib import Path -import pytest import pandas as pd +import pytest from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant from metagpt.schema import Message @@ -24,82 +24,39 @@ from metagpt.schema import Message "Invoicing date", Path("../../data/invoices/invoice-1.pdf"), Path("../../../data/invoice_table/invoice-1.xlsx"), - [ - { - "收款人": "小明", - "城市": "深圳市", - "总费用/元": 412.00, - "开票日期": "2023年02月03日" - } - ] + [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], ), ( "Invoicing date", Path("../../data/invoices/invoice-2.png"), Path("../../../data/invoice_table/invoice-2.xlsx"), - [ - { - "收款人": "铁头", - "城市": "广州市", - "总费用/元": 898.00, - "开票日期": "2023年03月17日" - } - ] + [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], ), ( "Invoicing date", Path("../../data/invoices/invoice-3.jpg"), Path("../../../data/invoice_table/invoice-3.xlsx"), - [ - { - "收款人": "夏天", - "城市": "福州市", - "总费用/元": 2462.00, - "开票日期": "2023年08月26日" - } - ] + [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], ), ( "Invoicing date", Path("../../data/invoices/invoice-4.zip"), Path("../../../data/invoice_table/invoice-4.xlsx"), [ - { - "收款人": "小明", - "城市": "深圳市", - "总费用/元": 412.00, - "开票日期": "2023年02月03日" - }, - { - "收款人": "铁头", - "城市": "广州市", - "总费用/元": 898.00, - "开票日期": "2023年03月17日" - }, - { - "收款人": "夏天", - "城市": "福州市", - "总费用/元": 2462.00, - "开票日期": "2023年08月26日" - } - ] + {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, + {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, + {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, + ], ), - ] + ], ) async def test_invoice_ocr_assistant( - query: str, - invoice_path: Path, - invoice_table_path: Path, - expected_result: list[dict] + query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] ): invoice_path = Path.cwd() / invoice_path role = InvoiceOCRAssistant() - await role.run(Message( - content=query, - instruct_content={"file_path": invoice_path} - )) + await role.run(Message(content=query, instruct_content={"file_path": invoice_path})) invoice_table_path = Path.cwd() / invoice_table_path df = pd.read_excel(invoice_table_path) - dict_result = df.to_dict(orient='records') + dict_result = df.to_dict(orient="records") assert dict_result == expected_result - diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py index 01b5dae3b..dd130662d 100644 --- a/tests/metagpt/roles/test_researcher.py +++ b/tests/metagpt/roles/test_researcher.py @@ -11,10 +11,12 @@ async def mock_llm_ask(self, prompt: str, system_msgs): if "Please provide up to 2 necessary keywords" in prompt: return '["dataiku", "datarobot"]' elif "Provide up to 4 queries related to your research topic" in prompt: - return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \ + return ( + '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' '"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]' + ) elif "sort the remaining search results" in prompt: - return '[1,2]' + return "[1,2]" elif "Not relevant." in prompt: return "Not relevant" if random() > 0.5 else prompt[-100:] elif "provide a detailed research report" in prompt: diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index 945620cfc..105f976c3 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -12,10 +12,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant @pytest.mark.asyncio -@pytest.mark.parametrize( - ("language", "topic"), - [("Chinese", "Write a tutorial about Python")] -) +@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")]) async def test_tutorial_assistant(language: str, topic: str): topic = "Write a tutorial about MySQL" role = TutorialAssistant(language=language) @@ -24,4 +21,4 @@ async def test_tutorial_assistant(language: str, topic: str): title = filename.split("/")[-1].split(".")[0] async with aiofiles.open(filename, mode="r") as reader: content = await reader.read() - assert content.startswith(f"# {title}") \ No newline at end of file + assert content.startswith(f"# {title}") diff --git a/tests/metagpt/roles/test_ui.py b/tests/metagpt/roles/test_ui.py index 285bff323..2d9cb85c9 100644 --- a/tests/metagpt/roles/test_ui.py +++ b/tests/metagpt/roles/test_ui.py @@ -2,9 +2,8 @@ # @Date : 2023/7/22 02:40 # @Author : stellahong (stellahong@fuzhi.ai) # -from metagpt.software_company import SoftwareCompany from metagpt.roles import ProductManager - +from metagpt.software_company import SoftwareCompany from tests.metagpt.roles.ui_role import UI diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index 89dd726a8..285e8134c 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -14,7 +14,7 @@ from metagpt.logs import logger @pytest.mark.usefixtures("llm_api") class TestGPT: def test_llm_api_ask(self, llm_api): - answer = llm_api.ask('hello chatgpt') + answer = llm_api.ask("hello chatgpt") assert len(answer) > 0 # def test_gptapi_ask_batch(self, llm_api): @@ -22,22 +22,22 @@ class TestGPT: # assert len(answer) > 0 def test_llm_api_ask_code(self, llm_api): - answer = llm_api.ask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world']) + answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) assert len(answer) > 0 @pytest.mark.asyncio async def test_llm_api_aask(self, llm_api): - answer = await llm_api.aask('hello chatgpt') + answer = await llm_api.aask("hello chatgpt") assert len(answer) > 0 @pytest.mark.asyncio async def test_llm_api_aask_code(self, llm_api): - answer = await llm_api.aask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world']) + answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) assert len(answer) > 0 @pytest.mark.asyncio async def test_llm_api_costs(self, llm_api): - await llm_api.aask('hello chatgpt') + await llm_api.aask("hello chatgpt") costs = llm_api.get_costs() logger.info(costs) assert costs.total_cost > 0 diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 11503af1d..03341212b 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -18,17 +18,17 @@ def llm(): @pytest.mark.asyncio async def test_llm_aask(llm): - assert len(await llm.aask('hello world')) > 0 + assert len(await llm.aask("hello world")) > 0 @pytest.mark.asyncio async def test_llm_aask_batch(llm): - assert len(await llm.aask_batch(['hi', 'write python hello world.'])) > 0 + assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0 @pytest.mark.asyncio async def test_llm_acompletion(llm): - hello_msg = [{'role': 'user', 'content': 'hello'}] + hello_msg = [{"role": "user", "content": "hello"}] assert len(await llm.acompletion(hello_msg)) > 0 assert len(await llm.acompletion_batch([hello_msg])) > 0 assert len(await llm.acompletion_batch_text([hello_msg])) > 0 diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index e26f38381..ae6708943 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -11,26 +11,26 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe def test_message(): - msg = Message(role='User', content='WTF') - assert msg.to_dict()['role'] == 'User' - assert 'User' in str(msg) + msg = Message(role="User", content="WTF") + assert msg.to_dict()["role"] == "User" + assert "User" in str(msg) def test_all_messages(): - test_content = 'test_message' + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] for msg in msgs: assert msg.content == test_content def test_raw_message(): - msg = RawMessage(role='user', content='raw') - assert msg['role'] == 'user' - assert msg['content'] == 'raw' + msg = RawMessage(role="user", content="raw") + assert msg["role"] == "user" + assert msg["content"] == "raw" with pytest.raises(KeyError): - assert msg['1'] == 1, "KeyError: '1'" + assert msg["1"] == 1, "KeyError: '1'" diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 11fd804ec..22cfa58a4 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -9,6 +9,6 @@ from metagpt.roles import Role def test_role_desc(): - i = Role(profile='Sales', desc='Best Seller') - assert i.profile == 'Sales' - assert i._setting.desc == 'Best Seller' + i = Role(profile="Sales", desc="Best Seller") + assert i.profile == "Sales" + assert i._setting.desc == "Best Seller" diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12666e0d3..c154d77e1 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,13 +9,13 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage def test_messages(): - test_content = 'test_message' + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] text = str(msgs) - roles = ['user', 'system', 'assistant', 'QA'] + roles = ["user", "system", "assistant", "QA"] assert all([i in text for i in roles]) diff --git a/tests/metagpt/tools/test_code_interpreter.py b/tests/metagpt/tools/test_code_interpreter.py index 0eec3f80b..03d4ce8df 100644 --- a/tests/metagpt/tools/test_code_interpreter.py +++ b/tests/metagpt/tools/test_code_interpreter.py @@ -1,23 +1,22 @@ -import pytest -import pandas as pd from pathlib import Path -from tests.data import sales_desc, store_desc -from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator +import pandas as pd +import pytest + from metagpt.actions import Action from metagpt.logs import logger +from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator - -logger.add('./tests/data/test_ci.log') +logger.add("./tests/data/test_ci.log") stock = "./tests/data/baba_stock.csv" # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 class CreateStockIndicators(Action): @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") - async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame: + async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; - 指标生成对应的三列: SMA, BB_upper, BB_lower + 指标生成对应的三列: SMA, BB_upper, BB_lower """ ... @@ -25,18 +24,20 @@ class CreateStockIndicators(Action): @pytest.mark.asyncio async def test_actions(): # 计算指标 - indicators = ['Simple Moving Average', 'BollingerBands'] + indicators = ["Simple Moving Average", "BollingerBands"] stocker = CreateStockIndicators() df, msg = await stocker.run(stock, indicators=indicators) assert isinstance(df, pd.DataFrame) - assert 'Close' in df.columns - assert 'Date' in df.columns + assert "Close" in df.columns + assert "Date" in df.columns # 将df保存为文件,将文件路径传入到下一个action - df_path = './tests/data/stock_indicators.csv' + df_path = "./tests/data/stock_indicators.csv" df.to_csv(df_path) assert Path(df_path).is_file() # 可视化指标结果 - figure_path = './tests/data/figure_ci.png' + figure_path = "./tests/data/figure_ci.png" ci_ploter = OpenCodeInterpreter() - ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。") + ci_ploter.chat( + f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" + ) assert Path(figure_path).is_file() diff --git a/tests/metagpt/tools/test_prompt_generator.py b/tests/metagpt/tools/test_prompt_generator.py index d2e870c6d..ddbd2c43b 100644 --- a/tests/metagpt/tools/test_prompt_generator.py +++ b/tests/metagpt/tools/test_prompt_generator.py @@ -20,8 +20,9 @@ from metagpt.tools.prompt_writer import ( @pytest.mark.usefixtures("llm_api") def test_gpt_prompt_generator(llm_api): generator = GPTPromptGenerator() - example = "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " \ - "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g" + example = ( + "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g" + ) results = llm_api.ask_batch(generator.gen(example)) logger.info(results) @@ -46,7 +47,7 @@ def test_enron_template(llm_api): results = template.gen(subj) assert len(results) > 0 - assert any("Write an email with the subject \"Meeting Agenda\"." in r for r in results) + assert any('Write an email with the subject "Meeting Agenda".' in r for r in results) def test_beagec_template(): @@ -54,5 +55,6 @@ def test_beagec_template(): results = template.gen() assert len(results) > 0 - assert any("Edit and revise this document to improve its grammar, vocabulary, spelling, and style." - in r for r in results) + assert any( + "Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results + ) diff --git a/tests/metagpt/tools/test_sd_tool.py b/tests/metagpt/tools/test_sd_tool.py index 77e53c7dc..4edd8fb3b 100644 --- a/tests/metagpt/tools/test_sd_tool.py +++ b/tests/metagpt/tools/test_sd_tool.py @@ -4,7 +4,7 @@ # import os -from metagpt.tools.sd_engine import SDEngine, WORKSPACE_ROOT +from metagpt.tools.sd_engine import WORKSPACE_ROOT, SDEngine def test_sd_engine_init(): diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index a7fe063a6..25bce124a 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -16,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine class MockSearchEnine: async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: - rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)] + rets = [ + {"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results) + ] return "\n".join(rets) if as_string else rets @@ -34,10 +36,14 @@ class MockSearchEnine: (SearchEngineType.DUCK_DUCK_GO, None, 6, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False), - ], ) -async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ): +async def test_search_engine( + search_engine_typpe, + run_func, + max_results, + as_string, +): search_engine = SearchEngine(search_engine_typpe, run_func) rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) logger.info(rsp) diff --git a/tests/metagpt/tools/test_search_engine_meilisearch.py b/tests/metagpt/tools/test_search_engine_meilisearch.py index 8d2bb6494..d5f7d162b 100644 --- a/tests/metagpt/tools/test_search_engine_meilisearch.py +++ b/tests/metagpt/tools/test_search_engine_meilisearch.py @@ -13,7 +13,7 @@ import pytest from metagpt.logs import logger from metagpt.tools.search_engine_meilisearch import DataSource, MeilisearchEngine -MASTER_KEY = '116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk' +MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk" @pytest.fixture() @@ -29,7 +29,7 @@ def test_meilisearch(search_engine_server): search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY) # 假设有一个名为"books"的数据源,包含要添加的文档库 - books_data_source = DataSource(name='books', url='https://example.com/books') + books_data_source = DataSource(name="books", url="https://example.com/books") # 假设有一个名为"documents"的文档库,包含要添加的文档 documents = [ @@ -43,4 +43,4 @@ def test_meilisearch(search_engine_server): # 添加文档库到搜索引擎 search_engine.add_documents(books_data_source, documents) - logger.info(search_engine.search('Book 1')) + logger.info(search_engine.search("Book 1")) diff --git a/tests/metagpt/tools/test_summarize.py b/tests/metagpt/tools/test_summarize.py index cf616c144..6a372defb 100644 --- a/tests/metagpt/tools/test_summarize.py +++ b/tests/metagpt/tools/test_summarize.py @@ -20,7 +20,6 @@ CASES = [ 1. 请根据上下文,对用户搜索请求进行总结性回答,不要包括与请求无关的文本 2. 以 [正文](引用链接) markdown形式在正文中**自然标注**~5个文本(如商品词或类似文本段),以便跳转 3. 回复优雅、清晰,**绝不重复文本**,行文流畅,长度居中""", - """# 上下文 [{'title': '去厦门 有哪些推荐的美食? - 知乎', 'href': 'https://www.zhihu.com/question/286901854', 'body': '知乎,中文互联网高质量的问答社区和创作者聚集的原创内容平台,于 2011 年 1 月正式上线,以「让人们更好的分享知识、经验和见解,找到自己的解答」为品牌使命。知乎凭借认真、专业、友善的社区氛围、独特的产品机制以及结构化和易获得的优质内容,聚集了中文互联网科技、商业、影视 ...'}, {'title': '厦门到底有哪些真正值得吃的美食? - 知乎', 'href': 'https://www.zhihu.com/question/38012322', 'body': '有几个特色菜在别处不太能吃到,值得一试~常点的有西多士、沙茶肉串、咕老肉(个人认为还是良山排档的更炉火纯青~),因为爱吃芋泥,每次还会点一个芋泥鸭~人均50元左右. 潮福城. 厦门这两年经营港式茶点的店越来越多,但是最经典的还是潮福城的茶点 ...'}, {'title': '超全厦门美食攻略,好吃不贵不踩雷 - 知乎 - 知乎专栏', 'href': 'https://zhuanlan.zhihu.com/p/347055615', 'body': '厦门老字号店铺,味道卫生都有保障,喜欢吃芒果的,不要错过芒果牛奶绵绵冰. 285蚝味馆 70/人. 上过《舌尖上的中国》味道不用多说,想吃地道的海鲜烧烤就来这里. 堂宴.老厦门私房菜 80/人. 非常多的明星打卡过,上过《十二道锋味》,吃厦门传统菜的好去处 ...'}, {'title': '福建名小吃||寻味厦门,十大特色名小吃,你都吃过哪几样? - 知乎', 'href': 'https://zhuanlan.zhihu.com/p/375781836', 'body': '第一期,分享厦门的特色美食。 厦门是一个风景旅游城市,许多人来到厦门,除了游览厦门独特的风景之外,最难忘的应该是厦门的特色小吃。厦门小吃多种多样,有到厦门必吃的沙茶面、米线糊、蚵仔煎、土笋冻等非常之多。那么,厦门的名小吃有哪些呢?'}, {'title': '大家如果去厦门旅游的话,好吃的有很多,但... 来自庄时利和 - 微博', 'href': 'https://weibo.com/1728715190/MEAwzscRT', 'body': '大家如果去厦门旅游的话,好吃的有很多,但如果只选一样的话,我个人会选择莲花煎蟹。 靠海吃海,吃蟹对于闽南人来说是很平常的一件事。 厦门传统的做法多是清蒸或水煮,上世纪八十年代有一同安人在厦门的莲花公园旁,摆摊做起了煎蟹的生意。'}, {'title': '厦门美食,厦门美食攻略,厦门旅游美食攻略 - 马蜂窝', 'href': 'https://www.mafengwo.cn/cy/10132/gonglve.html', 'body': '醉壹号海鲜大排档 (厦门美食地标店) No.3. 哆啦Eanny 的最新点评:. 环境 挺复古的闽南风情,花砖地板,一楼有海鲜自己点菜,二楼室内位置,三楼露天位置,环境挺不错的。. 苦螺汤,看起来挺清的,螺肉吃起来很脆。. 姜... 5.0 分. 482 条用户点评.'}, {'title': '厦门超强中山路小吃合集,29家本地人推荐的正宗美食 - 马蜂窝', 'href': 'https://www.mafengwo.cn/gonglve/ziyouxing/176485.html', 'body': '莲欢海蛎煎. 提到厦门就想到海蛎煎,而这家位于中山路局口街的莲欢海蛎煎是实打实的好吃!. ·局口街老巷之中,全室外环境,吃的就是这种感觉。. ·取名"莲欢",是希望妻子每天开心。. 新鲜的食材,实在的用料,这样的用心也定能讨食客欢心。. ·海蛎又 ...'}, {'title': '厦门市 10 大餐厅- Tripadvisor', 'href': 'https://cn.tripadvisor.com/Restaurants-g297407-Xiamen_Fujian.html', 'body': '厦门市餐厅:在Tripadvisor查看中国厦门市餐厅的点评,并以价格、地点及更多选项进行搜索。 ... "牛排太好吃了啊啊啊" ... "厦门地区最老品牌最有口碑的潮州菜餐厅" ...'}, {'title': '#福建10条美食街简直不要太好吃#每到一... 来自新浪厦门 - 微博', 'href': 'https://weibo.com/1740522895/MF1lY7W4n', 'body': '福建的这10条美食街,你一定不能错过!福州师大学生街、福州达明路美食街、厦门八市、漳州古城老街、宁德老南门电影院美食集市、龙岩中山路美食街、三明龙岗夜市、莆田金鼎夜市、莆田玉湖夜市、南平嘉禾美食街。世间万事皆难,唯有美食可以治愈一切。'}, {'title': '厦门这50家餐厅最值得吃 - 腾讯新闻', 'href': 'https://new.qq.com/rain/a/20200114A09HJT00', 'body': '没有什么事是一顿辣解决不了的! 创意辣、川湘辣、温柔辣、异域辣,芙蓉涧的菜能把辣椒玩出花来! ... 早在2005年,这家老牌的东南亚餐厅就开在厦门莲花了,在许多老厦门的心中,都觉得这里有全厦门最好吃的咖喱呢。 ...'}, {'title': '好听的美食?又好听又好吃的食物有什么? - 哔哩哔哩', 'href': 'https://www.bilibili.com/read/cv23430069/', 'body': '专栏 / 好听的美食?又好听又好吃的食物有什么? 又好听又好吃的食物有什么? 2023-05-02 18:01 --阅读 · --喜欢 · --评论'}] @@ -31,7 +30,7 @@ CASES = [ 你是专业管家团队的一员,会给出有帮助的建议 1. 请根据上下文,对用户搜索请求进行总结性回答,不要包括与请求无关的文本 2. 以 [正文](引用链接) markdown形式在正文中**自然标注**3-5个文本(如商品词或类似文本段),以便跳转 -3. 回复优雅、清晰,**绝不重复文本**,行文流畅,长度居中""" +3. 回复优雅、清晰,**绝不重复文本**,行文流畅,长度居中""", ] diff --git a/tests/metagpt/tools/test_translate.py b/tests/metagpt/tools/test_translate.py index 47a9034a5..024bda3ca 100644 --- a/tests/metagpt/tools/test_translate.py +++ b/tests/metagpt/tools/test_translate.py @@ -16,7 +16,7 @@ from metagpt.tools.translator import Translator def test_translate(llm_api): poetries = [ ("Let life be beautiful like summer flowers", "花"), - ("The ancient Chinese poetries are all songs.", "中国") + ("The ancient Chinese poetries are all songs.", "中国"), ] for i, j in poetries: prompt = Translator.translate_prompt(i) diff --git a/tests/metagpt/tools/test_ut_generator.py b/tests/metagpt/tools/test_ut_generator.py index 6f29999d4..2ae94885f 100644 --- a/tests/metagpt/tools/test_ut_generator.py +++ b/tests/metagpt/tools/test_ut_generator.py @@ -16,8 +16,12 @@ class TestUTWriter: tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"] # 这里在文件中手动加入了两个测试标签的API - utg = UTGenerator(swagger_file=swagger_file, ut_py_path=UT_PY_PATH, questions_path=API_QUESTIONS_PATH, - template_prefix=YFT_PROMPT_PREFIX) + utg = UTGenerator( + swagger_file=swagger_file, + ut_py_path=UT_PY_PATH, + questions_path=API_QUESTIONS_PATH, + template_prefix=YFT_PROMPT_PREFIX, + ) ret = utg.generate_ut(include_tags=tags) # 后续加入对文件生成内容与数量的检验 assert ret diff --git a/tests/metagpt/utils/test_code_parser.py b/tests/metagpt/utils/test_code_parser.py index 707b558e1..6b7349cd9 100644 --- a/tests/metagpt/utils/test_code_parser.py +++ b/tests/metagpt/utils/test_code_parser.py @@ -131,10 +131,10 @@ class TestCodeParser: def test_parse_file_list(self, parser, text): result = parser.parse_file_list("Task list", text) print(result) - assert result == ['task1', 'task2'] + assert result == ["task1", "task2"] -if __name__ == '__main__': +if __name__ == "__main__": t = TestCodeParser() t.test_parse_file_list(CodeParser(), t_text) # TestCodeParser.test_parse_file_list() diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index ec4443175..d3837ca8f 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -16,12 +16,12 @@ from metagpt.const import get_project_root class TestGetProjectRoot: def change_etc_dir(self): # current_directory = Path.cwd() - abs_root = '/etc' + abs_root = "/etc" os.chdir(abs_root) def test_get_project_root(self): project_root = get_project_root() - assert project_root.name == 'metagpt' + assert project_root.name == "metagpt" def test_get_root_exception(self): with pytest.raises(Exception) as exc_info: diff --git a/tests/metagpt/utils/test_config.py b/tests/metagpt/utils/test_config.py index 558a4e5a4..b68a535f9 100644 --- a/tests/metagpt/utils/test_config.py +++ b/tests/metagpt/utils/test_config.py @@ -20,12 +20,12 @@ def test_config_class_is_singleton(): def test_config_class_get_key_exception(): with pytest.raises(Exception) as exc_info: config = Config() - config.get('wtf') + config.get("wtf") assert str(exc_info.value) == "Key 'wtf' not found in environment variables or in the YAML file" def test_config_yaml_file_not_exists(): - config = Config('wtf.yaml') + config = Config("wtf.yaml") with pytest.raises(Exception) as exc_info: - config.get('OPENAI_BASE_URL') + config.get("OPENAI_BASE_URL") assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file" diff --git a/tests/metagpt/utils/test_custom_aio_session.py b/tests/metagpt/utils/test_custom_aio_session.py index 3a8a7bf7e..e2876e4b8 100644 --- a/tests/metagpt/utils/test_custom_aio_session.py +++ b/tests/metagpt/utils/test_custom_aio_session.py @@ -10,12 +10,12 @@ from metagpt.provider.openai_api import OpenAIGPTAPI async def try_hello(api): - batch = [[{'role': 'user', 'content': 'hello'}]] + batch = [[{"role": "user", "content": "hello"}]] results = await api.acompletion_batch_text(batch) return results async def aask_batch(api: OpenAIGPTAPI): - results = await api.aask_batch(['hi', 'write python hello world.']) + results = await api.aask_batch(["hi", "write python hello world."]) logger.info(results) return results diff --git a/tests/metagpt/utils/test_file.py b/tests/metagpt/utils/test_file.py index b30e6be93..83e317213 100644 --- a/tests/metagpt/utils/test_file.py +++ b/tests/metagpt/utils/test_file.py @@ -15,12 +15,11 @@ from metagpt.utils.file import File @pytest.mark.asyncio @pytest.mark.parametrize( ("root_path", "filename", "content"), - [(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")] + [(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")], ) async def test_write_and_read_file(root_path: Path, filename: str, content: bytes): - full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8')) + full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode("utf-8")) assert isinstance(full_file_name, Path) assert root_path / filename == full_file_name file_data = await File.read(full_file_name) assert file_data.decode("utf-8") == content - diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index 4e362f9f7..7a3aedbe8 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -14,17 +14,17 @@ from metagpt.utils.common import OutputParser def test_parse_blocks(): test_text = "##block1\nThis is block 1.\n##block2\nThis is block 2." - expected_result = {'block1': 'This is block 1.', 'block2': 'This is block 2.'} + expected_result = {"block1": "This is block 1.", "block2": "This is block 2."} assert OutputParser.parse_blocks(test_text) == expected_result def test_parse_code(): test_text = "```python\nprint('Hello, world!')```" expected_result = "print('Hello, world!')" - assert OutputParser.parse_code(test_text, 'python') == expected_result + assert OutputParser.parse_code(test_text, "python") == expected_result with pytest.raises(Exception): - OutputParser.parse_code(test_text, 'java') + OutputParser.parse_code(test_text, "java") def test_parse_python_code(): @@ -45,13 +45,13 @@ def test_parse_python_code(): def test_parse_str(): test_text = "name = 'Alice'" - expected_result = 'Alice' + expected_result = "Alice" assert OutputParser.parse_str(test_text) == expected_result def test_parse_file_list(): test_text = "files=['file1', 'file2', 'file3']" - expected_result = ['file1', 'file2', 'file3'] + expected_result = ["file1", "file2", "file3"] assert OutputParser.parse_file_list(test_text) == expected_result with pytest.raises(Exception): @@ -60,7 +60,7 @@ def test_parse_file_list(): def test_parse_data(): test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']" - expected_result = {'block1': "print('Hello, world!')", 'block2': ['file1', 'file2', 'file3']} + expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]} assert OutputParser.parse_data(test_data) == expected_result @@ -103,9 +103,11 @@ def test_parse_data(): None, Exception, ), - ] + ], ) -def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception): +def test_extract_struct( + text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception +): def case(): resp = OutputParser.extract_struct(text, data_type) assert resp == parsed_data @@ -117,7 +119,7 @@ def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], par case() -if __name__ == '__main__': +if __name__ == "__main__": t_text = ''' ## Required Python third-party packages ```python @@ -216,7 +218,7 @@ We need clarification on how the high score should be stored. Should it persist "Requirement Pool": (List[Tuple[str, str]], ...), "Anything UNCLEAR": (str, ...), } - t_text1 = '''## Original Requirements: + t_text1 = """## Original Requirements: The boss wants to create a web-based version of the game "Fly Bird". @@ -284,7 +286,7 @@ The product should be a web-based version of the game "Fly Bird" that is engagin ## Anything UNCLEAR: There are no unclear points. - ''' + """ d = OutputParser.parse_data_with_mapping(t_text1, OUTPUT_MAPPING) import json diff --git a/tests/metagpt/utils/test_parse_html.py b/tests/metagpt/utils/test_parse_html.py index 42be416a6..dd15bd80b 100644 --- a/tests/metagpt/utils/test_parse_html.py +++ b/tests/metagpt/utils/test_parse_html.py @@ -52,9 +52,11 @@ PAGE = """ """ -CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\ -'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\ -'with a class "box".a link' +CONTENT = ( + "This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered " + "Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div " + 'with a class "box".a link' +) def test_web_page(): diff --git a/tests/metagpt/utils/test_pycst.py b/tests/metagpt/utils/test_pycst.py index 07352eac2..9cf876611 100644 --- a/tests/metagpt/utils/test_pycst.py +++ b/tests/metagpt/utils/test_pycst.py @@ -1,6 +1,6 @@ from metagpt.utils import pycst -code = ''' +code = """ #!/usr/bin/env python # -*- coding: utf-8 -*- from typing import overload @@ -24,7 +24,7 @@ class Person: def greet(self): return f"Hello, my name is {self.name} and I am {self.age} years old." -''' +""" documented_code = ''' """ diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py index 0caf8abaa..7003c7767 100644 --- a/tests/metagpt/utils/test_text.py +++ b/tests/metagpt/utils/test_text.py @@ -29,7 +29,7 @@ def _paragraphs(n): (_msgs(), "gpt-4", "Hello," * 1000, 2000, 2), (_msgs(), "gpt-4-32k", "System", 4000, 14), (_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12), - ] + ], ) def test_reduce_message_length(msgs, model_name, system_text, reserved, expected): assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected @@ -42,7 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1), (" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2), (" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1), - ] + ], ) def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected): ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)) @@ -58,7 +58,7 @@ def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, r ("......", ".", 2, ["...", "..."]), ("......", ".", 3, ["..", "..", ".."]), (".......", ".", 2, ["....", "..."]), - ] + ], ) def test_split_paragraph(paragraph, sep, count, expected): ret = split_paragraph(paragraph, sep, count) @@ -71,7 +71,7 @@ def test_split_paragraph(paragraph, sep, count, expected): ("Hello\\nWorld", "Hello\nWorld"), ("Hello\\tWorld", "Hello\tWorld"), ("Hello\\u0020World", "Hello World"), - ] + ], ) def test_decode_unicode_escape(text, expected): assert decode_unicode_escape(text) == expected