From 344d87d2e6036b7364beb2d5af14bc3e8d1e8c1a Mon Sep 17 00:00:00 2001 From: didi <84363704+didiforgithub@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:14:15 +0800 Subject: [PATCH] Update --- metagpt/actions/action_node.py | 2 +- metagpt/ext/aflow/benchmark/drop.py | 1 + metagpt/ext/aflow/benchmark/gsm8k.py | 1 + metagpt/ext/aflow/benchmark/hotpotqa.py | 1 + metagpt/ext/aflow/benchmark/humaneval.py | 4 +- metagpt/ext/aflow/benchmark/math.py | 1 + metagpt/ext/aflow/benchmark/mbpp.py | 2 +- metagpt/ext/aflow/benchmark/utils.py | 177 +-------------------- metagpt/ext/aflow/data/download_data.py | 1 + metagpt/ext/aflow/scripts/operator.py | 1 - metagpt/ext/aflow/scripts/utils.py | 186 ++++++++++++++++++++++- 11 files changed, 191 insertions(+), 186 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6648b6e05..c286b2fdd 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -18,8 +18,8 @@ from pydantic import BaseModel, Field, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls -from metagpt.ext.aflow.benchmark.utils import sanitize from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.ext.aflow.scripts.utils import sanitize from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess diff --git a/metagpt/ext/aflow/benchmark/drop.py b/metagpt/ext/aflow/benchmark/drop.py index a3aa40740..3cec5795f 100644 --- a/metagpt/ext/aflow/benchmark/drop.py +++ b/metagpt/ext/aflow/benchmark/drop.py @@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark from metagpt.logs import logger + class DROPBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) diff --git a/metagpt/ext/aflow/benchmark/gsm8k.py b/metagpt/ext/aflow/benchmark/gsm8k.py index 898b3cc4f..51979c0c5 100644 --- a/metagpt/ext/aflow/benchmark/gsm8k.py +++ b/metagpt/ext/aflow/benchmark/gsm8k.py @@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark from metagpt.logs import logger + class GSM8KBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) diff --git a/metagpt/ext/aflow/benchmark/hotpotqa.py b/metagpt/ext/aflow/benchmark/hotpotqa.py index 0086fdd04..b3bafe22b 100644 --- a/metagpt/ext/aflow/benchmark/hotpotqa.py +++ b/metagpt/ext/aflow/benchmark/hotpotqa.py @@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark from metagpt.logs import logger + class HotpotQABenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) diff --git a/metagpt/ext/aflow/benchmark/humaneval.py b/metagpt/ext/aflow/benchmark/humaneval.py index abe44e6e4..36771ad7a 100644 --- a/metagpt/ext/aflow/benchmark/humaneval.py +++ b/metagpt/ext/aflow/benchmark/humaneval.py @@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed -from metagpt.actions.code_sanitize import sanitize -from metagpt.logs import logger from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.ext.aflow.scripts.utils import sanitize +from metagpt.logs import logger class HumanEvalBenchmark(BaseBenchmark): diff --git a/metagpt/ext/aflow/benchmark/math.py b/metagpt/ext/aflow/benchmark/math.py index 3db5e9596..61d994b69 100644 --- a/metagpt/ext/aflow/benchmark/math.py +++ b/metagpt/ext/aflow/benchmark/math.py @@ -11,6 +11,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark from metagpt.logs import logger + class MATHBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) diff --git a/metagpt/ext/aflow/benchmark/mbpp.py b/metagpt/ext/aflow/benchmark/mbpp.py index 2d9df3745..4e446404b 100644 --- a/metagpt/ext/aflow/benchmark/mbpp.py +++ b/metagpt/ext/aflow/benchmark/mbpp.py @@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from metagpt.actions.code_sanitize import sanitize -from metagpt.logs import logger from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger class MBPPBenchmark(BaseBenchmark): diff --git a/metagpt/ext/aflow/benchmark/utils.py b/metagpt/ext/aflow/benchmark/utils.py index 37d655d51..944fde6be 100644 --- a/metagpt/ext/aflow/benchmark/utils.py +++ b/metagpt/ext/aflow/benchmark/utils.py @@ -9,14 +9,8 @@ import json import os -import ast -import traceback -import numpy as np -from enum import Enum -from typing import Dict, Generator, List, Optional, Set, Tuple -import tree_sitter_python -from tree_sitter import Language, Node, Parser +import numpy as np def generate_random_indices(n, n_samples, test=False): @@ -82,172 +76,3 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path): # 将数据写回到log.json文件 with open(log_file, "w", encoding="utf-8") as f: json.dump(data, f, indent=4, ensure_ascii=False) - - -class NodeType(Enum): - CLASS = "class_definition" - FUNCTION = "function_definition" - IMPORT = ["import_statement", "import_from_statement"] - IDENTIFIER = "identifier" - ATTRIBUTE = "attribute" - RETURN = "return_statement" - EXPRESSION = "expression_statement" - ASSIGNMENT = "assignment" - - -def traverse_tree(node: Node) -> Generator[Node, None, None]: - """ - Traverse the tree structure starting from the given node. - - :param node: The root node to start the traversal from. - :return: A generator object that yields nodes in the tree. - """ - cursor = node.walk() - depth = 0 - - visited_children = False - while True: - if not visited_children: - yield cursor.node - if not cursor.goto_first_child(): - depth += 1 - visited_children = True - elif cursor.goto_next_sibling(): - visited_children = False - elif not cursor.goto_parent() or depth == 0: - break - else: - depth -= 1 - - -def syntax_check(code, verbose=False): - try: - ast.parse(code) - return True - except (SyntaxError, MemoryError): - if verbose: - traceback.print_exc() - return False - - -def code_extract(text: str) -> str: - lines = text.split("\n") - longest_line_pair = (0, 0) - longest_so_far = 0 - - for i in range(len(lines)): - for j in range(i + 1, len(lines)): - current_lines = "\n".join(lines[i : j + 1]) - if syntax_check(current_lines): - current_length = sum(1 for line in lines[i : j + 1] if line.strip()) - if current_length > longest_so_far: - longest_so_far = current_length - longest_line_pair = (i, j) - - return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) - - -def get_definition_name(node: Node) -> str: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - return child.text.decode("utf8") - - -def has_return_statement(node: Node) -> bool: - traverse_nodes = traverse_tree(node) - for node in traverse_nodes: - if node.type == NodeType.RETURN.value: - return True - return False - - -def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: - def dfs_get_deps(node: Node, deps: Set[str]) -> None: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - deps.add(child.text.decode("utf8")) - else: - dfs_get_deps(child, deps) - - name2deps = {} - for name, node in nodes: - deps = set() - dfs_get_deps(node, deps) - name2deps[name] = deps - return name2deps - - -def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: - queue = [entrypoint] - visited = {entrypoint} - while queue: - current = queue.pop(0) - if current not in call_graph: - continue - for neighbour in call_graph[current]: - if neighbour not in visited: - visited.add(neighbour) - queue.append(neighbour) - return visited - - -def sanitize(code: str, entrypoint: Optional[str] = None) -> str: - """ - Sanitize and extract relevant parts of the given Python code. - This function parses the input code, extracts import statements, class and function definitions, - and variable assignments. If an entrypoint is provided, it only includes definitions that are - reachable from the entrypoint in the call graph. - - :param code: The input Python code as a string. - :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. - :return: A sanitized version of the input code, containing only relevant parts. - """ - code = code_extract(code) - code_bytes = bytes(code, "utf8") - parser = Parser(Language(tree_sitter_python.language())) - tree = parser.parse(code_bytes) - class_names = set() - function_names = set() - variable_names = set() - - root_node = tree.root_node - import_nodes = [] - definition_nodes = [] - - for child in root_node.children: - if child.type in NodeType.IMPORT.value: - import_nodes.append(child) - elif child.type == NodeType.CLASS.value: - name = get_definition_name(child) - if not (name in class_names or name in variable_names or name in function_names): - definition_nodes.append((name, child)) - class_names.add(name) - elif child.type == NodeType.FUNCTION.value: - name = get_definition_name(child) - if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( - child - ): - definition_nodes.append((name, child)) - function_names.add(get_definition_name(child)) - elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: - subchild = child.children[0] - name = get_definition_name(subchild) - if not (name in variable_names or name in function_names or name in class_names): - definition_nodes.append((name, subchild)) - variable_names.add(name) - - if entrypoint: - name2deps = get_deps(definition_nodes) - reacheable = get_function_dependency(entrypoint, name2deps) - - sanitized_output = b"" - - for node in import_nodes: - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - - for pair in definition_nodes: - name, node = pair - if entrypoint and name not in reacheable: - continue - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - return sanitized_output[:-1].decode("utf8") diff --git a/metagpt/ext/aflow/data/download_data.py b/metagpt/ext/aflow/data/download_data.py index 219e2fec7..198ef20c1 100644 --- a/metagpt/ext/aflow/data/download_data.py +++ b/metagpt/ext/aflow/data/download_data.py @@ -12,6 +12,7 @@ from tqdm import tqdm from metagpt.logs import logger + def download_file(url: str, filename: str) -> None: """Download a file from the given URL and show progress.""" response = requests.get(url, stream=True) diff --git a/metagpt/ext/aflow/scripts/operator.py b/metagpt/ext/aflow/scripts/operator.py index 0d71f0a2f..9d27c7cd1 100644 --- a/metagpt/ext/aflow/scripts/operator.py +++ b/metagpt/ext/aflow/scripts/operator.py @@ -13,7 +13,6 @@ from typing import Dict, List, Tuple from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_node import ActionNode -from metagpt.logs import logger from metagpt.ext.aflow.scripts.operator_an import ( AnswerGenerateOp, CodeGenerateOp, diff --git a/metagpt/ext/aflow/scripts/utils.py b/metagpt/ext/aflow/scripts/utils.py index f69eaf4ca..36332ffec 100644 --- a/metagpt/ext/aflow/scripts/utils.py +++ b/metagpt/ext/aflow/scripts/utils.py @@ -1,12 +1,19 @@ -# -*- coding: utf-8 -*- -# @Date : 7/2/2024 17:36 PM -# @Author : didi -# @Desc : utils for experiment +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : utils.py +@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py +""" import ast import json import re -from typing import Any, List, Tuple +import traceback +from enum import Enum +from typing import Any, Dict, Generator, List, Optional, Set, Tuple + +import tree_sitter_python +from tree_sitter import Language, Node, Parser def extract_task_id(task_id: str) -> int: @@ -161,3 +168,172 @@ def test_check(): test_check() """ return tester_function + + +class NodeType(Enum): + CLASS = "class_definition" + FUNCTION = "function_definition" + IMPORT = ["import_statement", "import_from_statement"] + IDENTIFIER = "identifier" + ATTRIBUTE = "attribute" + RETURN = "return_statement" + EXPRESSION = "expression_statement" + ASSIGNMENT = "assignment" + + +def traverse_tree(node: Node) -> Generator[Node, None, None]: + """ + Traverse the tree structure starting from the given node. + + :param node: The root node to start the traversal from. + :return: A generator object that yields nodes in the tree. + """ + cursor = node.walk() + depth = 0 + + visited_children = False + while True: + if not visited_children: + yield cursor.node + if not cursor.goto_first_child(): + depth += 1 + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent() or depth == 0: + break + else: + depth -= 1 + + +def syntax_check(code, verbose=False): + try: + ast.parse(code) + return True + except (SyntaxError, MemoryError): + if verbose: + traceback.print_exc() + return False + + +def code_extract(text: str) -> str: + lines = text.split("\n") + longest_line_pair = (0, 0) + longest_so_far = 0 + + for i in range(len(lines)): + for j in range(i + 1, len(lines)): + current_lines = "\n".join(lines[i : j + 1]) + if syntax_check(current_lines): + current_length = sum(1 for line in lines[i : j + 1] if line.strip()) + if current_length > longest_so_far: + longest_so_far = current_length + longest_line_pair = (i, j) + + return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) + + +def get_definition_name(node: Node) -> str: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + return child.text.decode("utf8") + + +def has_return_statement(node: Node) -> bool: + traverse_nodes = traverse_tree(node) + for node in traverse_nodes: + if node.type == NodeType.RETURN.value: + return True + return False + + +def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: + def dfs_get_deps(node: Node, deps: Set[str]) -> None: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + deps.add(child.text.decode("utf8")) + else: + dfs_get_deps(child, deps) + + name2deps = {} + for name, node in nodes: + deps = set() + dfs_get_deps(node, deps) + name2deps[name] = deps + return name2deps + + +def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: + queue = [entrypoint] + visited = {entrypoint} + while queue: + current = queue.pop(0) + if current not in call_graph: + continue + for neighbour in call_graph[current]: + if neighbour not in visited: + visited.add(neighbour) + queue.append(neighbour) + return visited + + +def sanitize(code: str, entrypoint: Optional[str] = None) -> str: + """ + Sanitize and extract relevant parts of the given Python code. + This function parses the input code, extracts import statements, class and function definitions, + and variable assignments. If an entrypoint is provided, it only includes definitions that are + reachable from the entrypoint in the call graph. + + :param code: The input Python code as a string. + :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. + :return: A sanitized version of the input code, containing only relevant parts. + """ + code = code_extract(code) + code_bytes = bytes(code, "utf8") + parser = Parser(Language(tree_sitter_python.language())) + tree = parser.parse(code_bytes) + class_names = set() + function_names = set() + variable_names = set() + + root_node = tree.root_node + import_nodes = [] + definition_nodes = [] + + for child in root_node.children: + if child.type in NodeType.IMPORT.value: + import_nodes.append(child) + elif child.type == NodeType.CLASS.value: + name = get_definition_name(child) + if not (name in class_names or name in variable_names or name in function_names): + definition_nodes.append((name, child)) + class_names.add(name) + elif child.type == NodeType.FUNCTION.value: + name = get_definition_name(child) + if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( + child + ): + definition_nodes.append((name, child)) + function_names.add(get_definition_name(child)) + elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: + subchild = child.children[0] + name = get_definition_name(subchild) + if not (name in variable_names or name in function_names or name in class_names): + definition_nodes.append((name, subchild)) + variable_names.add(name) + + if entrypoint: + name2deps = get_deps(definition_nodes) + reacheable = get_function_dependency(entrypoint, name2deps) + + sanitized_output = b"" + + for node in import_nodes: + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + + for pair in definition_nodes: + name, node = pair + if entrypoint and name not in reacheable: + continue + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + return sanitized_output[:-1].decode("utf8")