From 97ab38ac7c4a19fc80ea56ca5ac2bc89f448f222 Mon Sep 17 00:00:00 2001 From: Stitch-z <284618289@qq.com> Date: Wed, 20 Sep 2023 12:02:46 +0800 Subject: [PATCH] Resolve PR issues --- examples/write_tutorial.py | 1 + metagpt/actions/research.py | 6 +++--- metagpt/actions/write_tutorial.py | 2 +- metagpt/utils/common.py | 10 +++++----- tests/metagpt/utils/test_output_parser.py | 16 ++++++++-------- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 73a9c71b7..71ece5527 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -18,3 +18,4 @@ async def main(): if __name__ == '__main__': asyncio.run(main()) + diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2dea28e2e..49a981e86 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -111,7 +111,7 @@ class CollectLinks(Action): system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic) keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) try: - keywords = OutputParser.extract_struct(keywords, "list") + keywords = OutputParser.extract_struct(keywords, list) keywords = parse_obj_as(list[str], keywords) except Exception as e: logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}") @@ -131,7 +131,7 @@ class CollectLinks(Action): logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: - queries = OutputParser.extract_struct(queries, "list") + queries = OutputParser.extract_struct(queries, list) queries = parse_obj_as(list[str], queries) except Exception as e: logger.exception(f"fail to break down the research question due to {e}") @@ -159,7 +159,7 @@ class CollectLinks(Action): logger.debug(prompt) indices = await self._aask(prompt) try: - indices = OutputParser.extract_struct(indices, "list") + indices = OutputParser.extract_struct(indices, list) assert all(isinstance(i, int) for i in indices) except Exception as e: logger.exception(f"fail to rank results for {e}") diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index 95f85d540..23e3560e8 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -37,7 +37,7 @@ class WriteDirectory(Action): """ prompt = DIRECTORY_PROMPT.format(topic=topic, language=self.language) resp = await self._aask(prompt=prompt) - return OutputParser.extract_struct(resp, "dict") + return OutputParser.extract_struct(resp, dict) class WriteContent(Action): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 37a4dbdb6..d0ab7e81d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -11,7 +11,7 @@ import inspect import os import platform import re -from typing import List, Tuple +from typing import List, Tuple, Union from metagpt.logs import logger @@ -152,7 +152,7 @@ class OutputParser: return parsed_data @classmethod - def extract_struct(cls, text: str, data_type: str) -> Tuple[list, dict]: + def extract_struct(cls, text: str, data_type: Union[type(list), type(dict)]) -> Union[list, dict]: """Extracts and parses a specified type of structure (dictionary or list) from the given text. The text only contains a list or dictionary, which may have nested structures. @@ -176,8 +176,8 @@ class OutputParser: >>> # Output: {"x": 1, "y": {"a": 2, "b": {"c": 3}}} """ # Find the first "[" or "{" and the last "]" or "}" - start_index = text.find("[" if data_type == "list" else "{") - end_index = text.rfind("]" if data_type == "list" else "}") + start_index = text.find("[" if data_type is list else "{") + end_index = text.rfind("]" if data_type is list else "}") if start_index != -1 and end_index != -1: # Extract the structure part @@ -188,7 +188,7 @@ class OutputParser: result = ast.literal_eval(structure_text) # Ensure the result matches the specified data type - if (data_type == "list" and isinstance(result, list)) or (data_type == "dict" and isinstance(result, dict)): + if isinstance(result, list) or isinstance(result, dict): return result raise ValueError(f"The extracted structure is not a {data_type}.") diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index c5ae73ac9..2b706efc4 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -5,7 +5,7 @@ @Author : chengmaoyu @File : test_output_parser.py """ -from typing import List, Tuple +from typing import List, Tuple, Union import pytest @@ -69,43 +69,43 @@ def test_parse_data(): [ ( """xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx""", - "list", + list, [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}], None, ), ( """xxx ["1", "2", "3"] xxx \n xxx \t xx""", - "list", + list, ["1", "2", "3"], None, ), ( """{"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}""", - "dict", + dict, {"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}, None, ), ( """xxx {"title": "x", \n \t "directory": ["x", \n "y"]} xxx \n xxx \t xx""", - "dict", + dict, {"title": "x", "directory": ["x", "y"]}, None, ), ( """xxx xx""", - "list", + list, None, Exception, ), ( """xxx [1, 2, []xx""", - "list", + list, None, Exception, ), ] ) -def test_extract_struct(text: str, data_type: str, parsed_data: list, expected_exception): +def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception): def case(): resp = OutputParser.extract_struct(text, data_type) assert resp == parsed_data