From 97be4311270edd95688fa75c2587820b46585203 Mon Sep 17 00:00:00 2001 From: Stitch-z <284618289@qq.com> Date: Tue, 19 Sep 2023 16:26:33 +0800 Subject: [PATCH 1/2] Fix research action bug && Optimize universal file operation tools --- examples/write_tutorial.py | 2 +- metagpt/actions/research.py | 7 +-- metagpt/actions/write_tutorial.py | 33 ++------------ metagpt/utils/common.py | 47 ++++++++++++++++++++ metagpt/utils/file.py | 48 ++++++++++++++++++-- tests/metagpt/utils/test_file.py | 8 ++-- tests/metagpt/utils/test_output_parser.py | 53 +++++++++++++++++++++++ 7 files changed, 155 insertions(+), 43 deletions(-) diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 167f3eb7c..73a9c71b7 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -17,4 +17,4 @@ async def main(): if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 81eb876dd..2dea28e2e 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -13,6 +13,7 @@ from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType +from metagpt.utils.common import OutputParser from metagpt.utils.text import generate_prompt_chunk, reduce_message_length LANG_PROMPT = "Please respond in {language}." @@ -110,7 +111,7 @@ class CollectLinks(Action): system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic) keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) try: - keywords = json.loads(keywords) + keywords = OutputParser.extract_struct(keywords, "list") keywords = parse_obj_as(list[str], keywords) except Exception as e: logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}") @@ -130,7 +131,7 @@ class CollectLinks(Action): logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: - queries = json.loads(queries) + queries = OutputParser.extract_struct(queries, "list") queries = parse_obj_as(list[str], queries) except Exception as e: logger.exception(f"fail to break down the research question due to {e}") @@ -158,7 +159,7 @@ class CollectLinks(Action): logger.debug(prompt) indices = await self._aask(prompt) try: - indices = json.loads(indices) + indices = OutputParser.extract_struct(indices, "list") assert all(isinstance(i, int) for i in indices) except Exception as e: logger.exception(f"fail to rank results for {e}") diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index b23fc2ad4..95f85d540 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -6,12 +6,12 @@ @File : tutorial_assistant.py @Describe : Actions of the tutorial assistant, including writing directories and document content. """ -import json + from typing import Dict from metagpt.actions import Action -from metagpt.logs import logger from metagpt.prompts.tutorial_assistant import DIRECTORY_PROMPT, CONTENT_PROMPT +from metagpt.utils.common import OutputParser class WriteDirectory(Action): @@ -26,33 +26,6 @@ class WriteDirectory(Action): super().__init__(name, *args, **kwargs) self.language = language - @staticmethod - async def _handle_resp(resp: str) -> Dict: - """Process string results and convert them to JSON format. - - Args: - resp: The directory results returned by gpt. - - Returns: - The parsed dictionary, such as {"title": "xxx", "directory": [{"dir 1": ["sub dir 1", "sub dir 2"]}]}. - - Raises: - Exception: If no matching dictionary section is found. - json.JSONDecodeError: If the dictionary part cannot be parsed as JSON. - """ - start = resp.find('{') - end = resp.rfind('}') - if start != -1 and end != -1 and end > start: - directory_str = resp[start:end + 1] - logger.info(f"Successfully parsed json: {str(directory_str)}") - try: - return json.loads(directory_str) - except json.JSONDecodeError as e: - logger.error(f"Json parsing error: {e}") - raise e - else: - raise Exception("No matching dictionary section found.") - async def run(self, topic: str, *args, **kwargs) -> Dict: """Execute the action to generate a tutorial directory according to the topic. @@ -64,7 +37,7 @@ class WriteDirectory(Action): """ prompt = DIRECTORY_PROMPT.format(topic=topic, language=self.language) resp = await self._aask(prompt=prompt) - return await self._handle_resp(resp) + return OutputParser.extract_struct(resp, "dict") class WriteContent(Action): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 5f94de066..37a4dbdb6 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -151,6 +151,53 @@ class OutputParser: parsed_data[block] = content return parsed_data + @classmethod + def extract_struct(cls, text: str, data_type: str) -> Tuple[list, dict]: + """Extracts and parses a specified type of structure (dictionary or list) from the given text. + The text only contains a list or dictionary, which may have nested structures. + + Args: + text: The text containing the structure (dictionary or list). + data_type: The data type to extract, can be "list" or "dict". + + Returns: + - If extraction and parsing are successful, it returns the corresponding data structure (list or dictionary). + - If extraction fails or parsing encounters an error, it throw an exception. + + Examples: + >>> text = 'xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx' + >>> result_list = OutputParser.extract_struct(text, "list") + >>> print(result_list) + >>> # Output: [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] + + >>> text = 'xxx {"x": 1, "y": {"a": 2, "b": {"c": 3}}} xxx' + >>> result_dict = OutputParser.extract_struct(text, "dict") + >>> print(result_dict) + >>> # Output: {"x": 1, "y": {"a": 2, "b": {"c": 3}}} + """ + # Find the first "[" or "{" and the last "]" or "}" + start_index = text.find("[" if data_type == "list" else "{") + end_index = text.rfind("]" if data_type == "list" else "}") + + if start_index != -1 and end_index != -1: + # Extract the structure part + structure_text = text[start_index:end_index + 1] + + try: + # Attempt to convert the text to a Python data type using ast.literal_eval + result = ast.literal_eval(structure_text) + + # Ensure the result matches the specified data type + if (data_type == "list" and isinstance(result, list)) or (data_type == "dict" and isinstance(result, dict)): + return result + + raise ValueError(f"The extracted structure is not a {data_type}.") + + except (ValueError, SyntaxError) as e: + raise Exception(f"Error while extracting and parsing the {data_type}: {e}") + else: + raise Exception(f"No {data_type} found in the text.") + class CodeParser: diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 5aca2a0e5..738b5a049 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -15,14 +15,17 @@ from metagpt.logs import logger class File: """A general util for file operations.""" + CHUNK_SIZE = 64 * 1024 + @classmethod - async def write(cls, root_path: Path, filename: str, content: bytes) -> Path: - """Write the file content to the local specified path. + async def write(cls, root_path: Path, filename: str, content: bytes, chunk_size: int = None) -> Path: + """Partitioning write the file content to the local specified path. Args: root_path: The root path of file, such as "/data". filename: The name of file, such as "test.txt". content: The binary content of file. + chunk_size: The size of each chunk in bytes (default is 64kb). Returns: The full filename of file, such as "/data/test.txt". @@ -31,12 +34,49 @@ class File: Exception: If an unexpected error occurs during the file writing process. """ try: + chunk_size = chunk_size or cls.CHUNK_SIZE root_path.mkdir(parents=True, exist_ok=True) full_path = root_path / filename async with aiofiles.open(full_path, mode="wb") as writer: - await writer.write(content) + for i in range(0, len(content), chunk_size): + chunk = content[i:i + chunk_size] + await writer.write(chunk) + # Flush the buffer to ensure data is written immediately + await writer.flush() logger.info(f"Successfully write file: {full_path}") return full_path except Exception as e: logger.error(f"Error writing file: {e}") - raise e \ No newline at end of file + raise e + + @classmethod + async def read(cls, file_path: Path, chunk_size: int = None) -> bytes: + """Partitioning read the file content from the local specified path. + + Args: + file_path: The full file name of file, such as "/data/test.txt". + chunk_size: The size of each chunk in bytes (default is 64kb). + + Returns: + The binary content of file. + + Raises: + Exception: If an unexpected error occurs during the file reading process. + """ + try: + if not file_path.exists(): + raise FileNotFoundError(f"File not found, path is '{file_path}'") + chunk_size = chunk_size or cls.CHUNK_SIZE + async with aiofiles.open(file_path, mode="rb") as reader: + content = bytes() + while True: + chunk = await reader.read(chunk_size) + if not chunk: + break + content += chunk + logger.info(f"Successfully read file, the size of file: {len(content)}") + return content + except Exception as e: + logger.error(f"Error reading file: {e}") + raise e + diff --git a/tests/metagpt/utils/test_file.py b/tests/metagpt/utils/test_file.py index a9f1a353d..2f224e558 100644 --- a/tests/metagpt/utils/test_file.py +++ b/tests/metagpt/utils/test_file.py @@ -7,7 +7,6 @@ """ from pathlib import Path -import aiofiles import pytest from metagpt.utils.file import File @@ -18,10 +17,9 @@ from metagpt.utils.file import File ("root_path", "filename", "content"), [(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")] ) -async def test_write_file(root_path: Path, filename: str, content: bytes): +async def test_write_and_read_file(root_path: Path, filename: str, content: bytes): full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8')) assert isinstance(full_file_name, Path) assert root_path / filename == full_file_name - async with aiofiles.open(full_file_name, mode="r") as reader: - body = await reader.read() - assert body == content \ No newline at end of file + file_data = await File.read(full_file_name) + assert file_data.decode("utf-8") == content diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index c56cff6fa..e779d6647 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -64,6 +64,59 @@ def test_parse_data(): assert OutputParser.parse_data(test_data) == expected_result +@pytest.mark.parametrize( + ("text", "data_type", "parsed_data", "expected_exception"), + [ + ( + """xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx""", + "list", + [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}], + None, + ), + ( + """xxx ["1", "2", "3"] xxx \n xxx \t xx""", + "list", + ["1", "2", "3"], + None, + ), + ( + """{"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}""", + "dict", + {"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}, + None, + ), + ( + """xxx {"title": "x", \n \t "directory": ["x", \n "y"]} xxx \n xxx \t xx""", + "dict", + {"title": "x", "directory": ["x", "y"]}, + None, + ), + ( + """xxx xx""", + "list", + None, + Exception, + ), + ( + """xxx [1, 2, []xx""", + "list", + None, + Exception, + ), + ] +) +def test_extract_list_or_dict(text: str, data_type: str, parsed_data: list, expected_exception): + def case(): + resp = OutputParser.extract_struct(text, data_type) + assert resp == parsed_data + + if expected_exception: + with pytest.raises(expected_exception): + case() + else: + case() + + if __name__ == '__main__': t_text = ''' ## Required Python third-party packages From c99f4bffe93b4a8f078be64a88d29a76a3d1ae2e Mon Sep 17 00:00:00 2001 From: Stitch-z <284618289@qq.com> Date: Tue, 19 Sep 2023 16:39:24 +0800 Subject: [PATCH 2/2] Fix research action bug && Optimize universal file operation tools --- tests/metagpt/utils/test_output_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index e779d6647..c5ae73ac9 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -105,7 +105,7 @@ def test_parse_data(): ), ] ) -def test_extract_list_or_dict(text: str, data_type: str, parsed_data: list, expected_exception): +def test_extract_struct(text: str, data_type: str, parsed_data: list, expected_exception): def case(): resp = OutputParser.extract_struct(text, data_type) assert resp == parsed_data