diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 167f3eb7c..71ece5527 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -17,4 +17,5 @@ async def main(): if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) + diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 81eb876dd..49a981e86 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -13,6 +13,7 @@ from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType +from metagpt.utils.common import OutputParser from metagpt.utils.text import generate_prompt_chunk, reduce_message_length LANG_PROMPT = "Please respond in {language}." @@ -110,7 +111,7 @@ class CollectLinks(Action): system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic) keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) try: - keywords = json.loads(keywords) + keywords = OutputParser.extract_struct(keywords, list) keywords = parse_obj_as(list[str], keywords) except Exception as e: logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}") @@ -130,7 +131,7 @@ class CollectLinks(Action): logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: - queries = json.loads(queries) + queries = OutputParser.extract_struct(queries, list) queries = parse_obj_as(list[str], queries) except Exception as e: logger.exception(f"fail to break down the research question due to {e}") @@ -158,7 +159,7 @@ class CollectLinks(Action): logger.debug(prompt) indices = await self._aask(prompt) try: - indices = json.loads(indices) + indices = OutputParser.extract_struct(indices, list) assert all(isinstance(i, int) for i in indices) except Exception as e: logger.exception(f"fail to rank results for {e}") diff 
--git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index b23fc2ad4..23e3560e8 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -6,12 +6,12 @@ @File : tutorial_assistant.py @Describe : Actions of the tutorial assistant, including writing directories and document content. """ -import json + from typing import Dict from metagpt.actions import Action -from metagpt.logs import logger from metagpt.prompts.tutorial_assistant import DIRECTORY_PROMPT, CONTENT_PROMPT +from metagpt.utils.common import OutputParser class WriteDirectory(Action): @@ -26,33 +26,6 @@ class WriteDirectory(Action): super().__init__(name, *args, **kwargs) self.language = language - @staticmethod - async def _handle_resp(resp: str) -> Dict: - """Process string results and convert them to JSON format. - - Args: - resp: The directory results returned by gpt. - - Returns: - The parsed dictionary, such as {"title": "xxx", "directory": [{"dir 1": ["sub dir 1", "sub dir 2"]}]}. - - Raises: - Exception: If no matching dictionary section is found. - json.JSONDecodeError: If the dictionary part cannot be parsed as JSON. - """ - start = resp.find('{') - end = resp.rfind('}') - if start != -1 and end != -1 and end > start: - directory_str = resp[start:end + 1] - logger.info(f"Successfully parsed json: {str(directory_str)}") - try: - return json.loads(directory_str) - except json.JSONDecodeError as e: - logger.error(f"Json parsing error: {e}") - raise e - else: - raise Exception("No matching dictionary section found.") - async def run(self, topic: str, *args, **kwargs) -> Dict: """Execute the action to generate a tutorial directory according to the topic. 
@classmethod
def extract_struct(cls, text: str, data_type: Union[type(list), type(dict)]) -> Union[list, dict]:
    """Extract and parse a single list or dict embedded in free-form text.

    The structure is taken to span from the first opening bracket to the last
    closing bracket of the requested kind, so nested structures are supported
    but the text is expected to contain only one top-level structure.

    The extracted text is parsed as a Python literal first; if that fails it is
    parsed as JSON, so JSON-only literals (``true``, ``false``, ``null``) are
    accepted as well.

    Args:
        text: The text containing the structure, possibly surrounded by other text.
        data_type: The type to extract. Must be the type object ``list`` or
            ``dict`` (not the strings "list" / "dict").

    Returns:
        The parsed list or dict.

    Raises:
        ValueError: If ``data_type`` is neither ``list`` nor ``dict``, if no
            matching brackets are found, if the extracted text cannot be
            parsed, or if the parsed value is not an instance of ``data_type``.
            ``ValueError`` is a subclass of ``Exception``, so callers that
            catch ``Exception`` keep working.

    Examples:
        >>> text = 'xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx'
        >>> OutputParser.extract_struct(text, list)
        [1, 2, ['a', 'b', [3, 4]], {'x': 5, 'y': [6, 7]}]

        >>> text = 'xxx {"x": 1, "y": {"a": 2, "b": {"c": 3}}} xxx'
        >>> OutputParser.extract_struct(text, dict)
        {'x': 1, 'y': {'a': 2, 'b': {'c': 3}}}
    """
    # Imported locally so this method does not depend on the module's import block.
    import json

    if data_type is list:
        opener, closer = "[", "]"
    elif data_type is dict:
        opener, closer = "{", "}"
    else:
        # Previously any non-list value was silently treated as dict.
        raise ValueError(f"Unsupported data_type {data_type!r}; expected list or dict.")
    kind = data_type.__name__

    start_index = text.find(opener)
    end_index = text.rfind(closer)
    # Also covers end_index == -1, and a closer that appears before the opener.
    if start_index == -1 or end_index < start_index:
        raise ValueError(f"No {kind} found in the text.")

    structure_text = text[start_index:end_index + 1]
    try:
        result = ast.literal_eval(structure_text)
    except (ValueError, SyntaxError, TypeError) as literal_error:
        # Not a Python literal; it may still be valid JSON (true / false / null).
        try:
            result = json.loads(structure_text)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            raise ValueError(
                f"Error while extracting and parsing the {kind}: {literal_error}"
            ) from literal_error

    # "{...}" can also evaluate to a set and "[..], [..]" to a tuple.
    if not isinstance(result, data_type):
        raise ValueError(f"The extracted structure is not a {kind}.")
    return result
@classmethod
async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
    """Read a local file in fixed-size chunks and return its full content.

    Args:
        file_path: The full path of the file to read, such as "/data/test.txt".
        chunk_size: The number of bytes requested per read call. A falsy value
            (``None`` or ``0``) selects ``cls.CHUNK_SIZE`` (64 KiB).

    Returns:
        The binary content of the file.

    Raises:
        Exception: Any error raised while opening or reading the file is
            logged and then re-raised.
    """
    try:
        size = cls.CHUNK_SIZE if not chunk_size else chunk_size
        async with aiofiles.open(file_path, mode="rb") as reader:
            pieces = []
            # An empty read result marks end of file.
            piece = await reader.read(size)
            while piece:
                pieces.append(piece)
                piece = await reader.read(size)
            data = b"".join(pieces)
            logger.debug(f"Successfully read file, the path of file: {file_path}")
            return data
    except Exception as e:
        logger.error(f"Error reading file: {e}")
        raise e
@pytest.mark.parametrize(
    ("text", "data_type", "parsed_data", "expected_exception"),
    [
        (
            """xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx""",
            list,
            [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}],
            None,
        ),
        (
            """xxx ["1", "2", "3"] xxx \n xxx \t xx""",
            list,
            ["1", "2", "3"],
            None,
        ),
        (
            """{"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}""",
            dict,
            {"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]},
            None,
        ),
        (
            """xxx {"title": "x", \n \t "directory": ["x", \n "y"]} xxx \n xxx \t xx""",
            dict,
            {"title": "x", "directory": ["x", "y"]},
            None,
        ),
        (
            """xxx xx""",
            list,
            None,
            Exception,
        ),
        (
            """xxx [1, 2, []xx""",
            list,
            None,
            Exception,
        ),
    ],
)
def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception):
    """Check OutputParser.extract_struct on parsable and unparsable inputs.

    Rows with ``expected_exception`` set must make ``extract_struct`` itself
    raise; all other rows must return exactly ``parsed_data``.
    """
    if expected_exception:
        # Only the call goes inside pytest.raises. If the equality assert were
        # inside too, its AssertionError (a subclass of Exception) would satisfy
        # pytest.raises(Exception) even when extract_struct returned normally,
        # so the failure rows could never fail.
        with pytest.raises(expected_exception):
            OutputParser.extract_struct(text, data_type)
        return

    assert OutputParser.extract_struct(text, data_type) == parsed_data
'__main__': t_text = ''' ## Required Python third-party packages