From 8c7cde533b92312a03dd213c5123bd82cac0e392 Mon Sep 17 00:00:00 2001 From: didi <84363704+didiforgithub@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:55:18 +0800 Subject: [PATCH] Transform print into logger.info & mv code sanitize to utils.py --- metagpt/actions/action_node.py | 2 +- metagpt/actions/code_sanitize.py | 184 ----------------------- metagpt/ext/aflow/benchmark/drop.py | 4 +- metagpt/ext/aflow/benchmark/gsm8k.py | 2 +- metagpt/ext/aflow/benchmark/hotpotqa.py | 2 +- metagpt/ext/aflow/benchmark/math.py | 2 +- metagpt/ext/aflow/benchmark/utils.py | 186 +++++++++++++++++++++++- metagpt/ext/aflow/scripts/operator.py | 4 +- 8 files changed, 193 insertions(+), 193 deletions(-) delete mode 100644 metagpt/actions/code_sanitize.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 1909d3835..6648b6e05 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, Field, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls -from metagpt.actions.code_sanitize import sanitize +from metagpt.ext.aflow.benchmark.utils import sanitize from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.llm import BaseLLM from metagpt.logs import logger diff --git a/metagpt/actions/code_sanitize.py b/metagpt/actions/code_sanitize.py deleted file mode 100644 index 56422589c..000000000 --- a/metagpt/actions/code_sanitize.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2024/7/24 16:37 -@Author : didi -@File : code_node.py -@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py -""" -import ast -import traceback -from enum import Enum -from typing import Dict, Generator, List, Optional, Set, Tuple - -import tree_sitter_python -from tree_sitter import Language, Node, Parser - - -class NodeType(Enum): - CLASS = "class_definition" - FUNCTION = "function_definition" - IMPORT = ["import_statement", "import_from_statement"] - IDENTIFIER = "identifier" - ATTRIBUTE = "attribute" - RETURN = "return_statement" - EXPRESSION = "expression_statement" - ASSIGNMENT = "assignment" - - -def traverse_tree(node: Node) -> Generator[Node, None, None]: - """ - Traverse the tree structure starting from the given node. - - :param node: The root node to start the traversal from. - :return: A generator object that yields nodes in the tree. - """ - cursor = node.walk() - depth = 0 - - visited_children = False - while True: - if not visited_children: - yield cursor.node - if not cursor.goto_first_child(): - depth += 1 - visited_children = True - elif cursor.goto_next_sibling(): - visited_children = False - elif not cursor.goto_parent() or depth == 0: - break - else: - depth -= 1 - - -def syntax_check(code, verbose=False): - try: - ast.parse(code) - return True - except (SyntaxError, MemoryError): - if verbose: - traceback.print_exc() - return False - - -def code_extract(text: str) -> str: - lines = text.split("\n") - longest_line_pair = (0, 0) - longest_so_far = 0 - - for i in range(len(lines)): - for j in range(i + 1, len(lines)): - current_lines = "\n".join(lines[i : j + 1]) - if syntax_check(current_lines): - current_length = sum(1 for line in lines[i : j + 1] if line.strip()) - if current_length > longest_so_far: - longest_so_far = current_length - longest_line_pair = (i, j) - - return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) - - -def get_definition_name(node: Node) -> str: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - return child.text.decode("utf8") - - -def has_return_statement(node: Node) -> bool: - traverse_nodes = traverse_tree(node) - for node in traverse_nodes: - if node.type == NodeType.RETURN.value: - return True - return False - - -def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: - def dfs_get_deps(node: Node, deps: Set[str]) -> None: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - deps.add(child.text.decode("utf8")) - else: - dfs_get_deps(child, deps) - - name2deps = {} - for name, node in nodes: - deps = set() - dfs_get_deps(node, deps) - name2deps[name] = deps - return name2deps - - -def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: - queue = [entrypoint] - visited = {entrypoint} - while queue: - current = queue.pop(0) - if current not in call_graph: - continue - for neighbour in call_graph[current]: - if neighbour not in visited: - visited.add(neighbour) - queue.append(neighbour) - return visited - - -def sanitize(code: str, entrypoint: Optional[str] = None) -> str: - """ - Sanitize and extract relevant parts of the given Python code. - This function parses the input code, extracts import statements, class and function definitions, - and variable assignments. If an entrypoint is provided, it only includes definitions that are - reachable from the entrypoint in the call graph. - - :param code: The input Python code as a string. - :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. - :return: A sanitized version of the input code, containing only relevant parts. - """ - code = code_extract(code) - code_bytes = bytes(code, "utf8") - parser = Parser(Language(tree_sitter_python.language())) - tree = parser.parse(code_bytes) - class_names = set() - function_names = set() - variable_names = set() - - root_node = tree.root_node - import_nodes = [] - definition_nodes = [] - - for child in root_node.children: - if child.type in NodeType.IMPORT.value: - import_nodes.append(child) - elif child.type == NodeType.CLASS.value: - name = get_definition_name(child) - if not (name in class_names or name in variable_names or name in function_names): - definition_nodes.append((name, child)) - class_names.add(name) - elif child.type == NodeType.FUNCTION.value: - name = get_definition_name(child) - if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( - child - ): - definition_nodes.append((name, child)) - function_names.add(get_definition_name(child)) - elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: - subchild = child.children[0] - name = get_definition_name(subchild) - if not (name in variable_names or name in function_names or name in class_names): - definition_nodes.append((name, subchild)) - variable_names.add(name) - - if entrypoint: - name2deps = get_deps(definition_nodes) - reacheable = get_function_dependency(entrypoint, name2deps) - - sanitized_output = b"" - - for node in import_nodes: - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - - for pair in definition_nodes: - name, node = pair - if entrypoint and name not in reacheable: - continue - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - return sanitized_output[:-1].decode("utf8") diff --git a/metagpt/ext/aflow/benchmark/drop.py b/metagpt/ext/aflow/benchmark/drop.py index 7963dcd45..a3aa40740 100644 --- a/metagpt/ext/aflow/benchmark/drop.py +++ b/metagpt/ext/aflow/benchmark/drop.py @@ -6,7 +6,7 @@ from typing import Callable, List, Tuple from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark - +from metagpt.logs import logger class DROPBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): @@ -75,7 +75,7 @@ class DROPBenchmark(BaseBenchmark): return input_text, output, expected_output, uni_score, cost except Exception as e: - print(f"Maximum retries reached. Skipping this sample. Error: {e}") + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") return input_text, str(e), expected_output, 0.0, 0.0 def get_result_columns(self) -> List[str]: diff --git a/metagpt/ext/aflow/benchmark/gsm8k.py b/metagpt/ext/aflow/benchmark/gsm8k.py index 5ecff8c7f..898b3cc4f 100644 --- a/metagpt/ext/aflow/benchmark/gsm8k.py +++ b/metagpt/ext/aflow/benchmark/gsm8k.py @@ -49,7 +49,7 @@ class GSM8KBenchmark(BaseBenchmark): return input_text, output, expected_output, score, cost except Exception as e: - print(f"Maximum retries reached. Skipping this sample. Error: {e}") + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") return input_text, str(e), expected_output, 0.0, 0.0 def get_result_columns(self) -> List[str]: diff --git a/metagpt/ext/aflow/benchmark/hotpotqa.py b/metagpt/ext/aflow/benchmark/hotpotqa.py index 85c15440e..0086fdd04 100644 --- a/metagpt/ext/aflow/benchmark/hotpotqa.py +++ b/metagpt/ext/aflow/benchmark/hotpotqa.py @@ -63,7 +63,7 @@ class HotpotQABenchmark(BaseBenchmark): return input_text, context_str, output, expected_output, score, cost except Exception as e: - print(f"Maximum retries reached. Skipping this sample. Error: {e}") + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") return input_text, context_str, str(e), expected_output, 0.0, 0.0 def get_result_columns(self) -> List[str]: diff --git a/metagpt/ext/aflow/benchmark/math.py b/metagpt/ext/aflow/benchmark/math.py index edc23c347..3db5e9596 100644 --- a/metagpt/ext/aflow/benchmark/math.py +++ b/metagpt/ext/aflow/benchmark/math.py @@ -115,7 +115,7 @@ class MATHBenchmark(BaseBenchmark): return input_text, output, expected_output, uni_score, cost except Exception as e: - print(f"Maximum retries reached. Skipping this sample. Error: {e}") + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") return input_text, str(e), expected_output, 0.0, 0.0 def get_result_columns(self) -> List[str]: diff --git a/metagpt/ext/aflow/benchmark/utils.py b/metagpt/ext/aflow/benchmark/utils.py index e620a52a3..37d655d51 100644 --- a/metagpt/ext/aflow/benchmark/utils.py +++ b/metagpt/ext/aflow/benchmark/utils.py @@ -1,7 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : utils.py +@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py +""" + import json import os - +import ast +import traceback import numpy as np +from enum import Enum +from typing import Dict, Generator, List, Optional, Set, Tuple + +import tree_sitter_python +from tree_sitter import Language, Node, Parser def generate_random_indices(n, n_samples, test=False): @@ -67,3 +82,172 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path): # 将数据写回到log.json文件 with open(log_file, "w", encoding="utf-8") as f: json.dump(data, f, indent=4, ensure_ascii=False) + + +class NodeType(Enum): + CLASS = "class_definition" + FUNCTION = "function_definition" + IMPORT = ["import_statement", "import_from_statement"] + IDENTIFIER = "identifier" + ATTRIBUTE = "attribute" + RETURN = "return_statement" + EXPRESSION = "expression_statement" + ASSIGNMENT = "assignment" + + +def traverse_tree(node: Node) -> Generator[Node, None, None]: + """ + Traverse the tree structure starting from the given node. + + :param node: The root node to start the traversal from. + :return: A generator object that yields nodes in the tree. + """ + cursor = node.walk() + depth = 0 + + visited_children = False + while True: + if not visited_children: + yield cursor.node + if not cursor.goto_first_child(): + depth += 1 + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent() or depth == 0: + break + else: + depth -= 1 + + +def syntax_check(code, verbose=False): + try: + ast.parse(code) + return True + except (SyntaxError, MemoryError): + if verbose: + traceback.print_exc() + return False + + +def code_extract(text: str) -> str: + lines = text.split("\n") + longest_line_pair = (0, 0) + longest_so_far = 0 + + for i in range(len(lines)): + for j in range(i + 1, len(lines)): + current_lines = "\n".join(lines[i : j + 1]) + if syntax_check(current_lines): + current_length = sum(1 for line in lines[i : j + 1] if line.strip()) + if current_length > longest_so_far: + longest_so_far = current_length + longest_line_pair = (i, j) + + return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) + + +def get_definition_name(node: Node) -> str: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + return child.text.decode("utf8") + + +def has_return_statement(node: Node) -> bool: + traverse_nodes = traverse_tree(node) + for node in traverse_nodes: + if node.type == NodeType.RETURN.value: + return True + return False + + +def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: + def dfs_get_deps(node: Node, deps: Set[str]) -> None: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + deps.add(child.text.decode("utf8")) + else: + dfs_get_deps(child, deps) + + name2deps = {} + for name, node in nodes: + deps = set() + dfs_get_deps(node, deps) + name2deps[name] = deps + return name2deps + + +def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: + queue = [entrypoint] + visited = {entrypoint} + while queue: + current = queue.pop(0) + if current not in call_graph: + continue + for neighbour in call_graph[current]: + if neighbour not in visited: + visited.add(neighbour) + queue.append(neighbour) + return visited + + +def sanitize(code: str, entrypoint: Optional[str] = None) -> str: + """ + Sanitize and extract relevant parts of the given Python code. + This function parses the input code, extracts import statements, class and function definitions, + and variable assignments. If an entrypoint is provided, it only includes definitions that are + reachable from the entrypoint in the call graph. + + :param code: The input Python code as a string. + :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. + :return: A sanitized version of the input code, containing only relevant parts. + """ + code = code_extract(code) + code_bytes = bytes(code, "utf8") + parser = Parser(Language(tree_sitter_python.language())) + tree = parser.parse(code_bytes) + class_names = set() + function_names = set() + variable_names = set() + + root_node = tree.root_node + import_nodes = [] + definition_nodes = [] + + for child in root_node.children: + if child.type in NodeType.IMPORT.value: + import_nodes.append(child) + elif child.type == NodeType.CLASS.value: + name = get_definition_name(child) + if not (name in class_names or name in variable_names or name in function_names): + definition_nodes.append((name, child)) + class_names.add(name) + elif child.type == NodeType.FUNCTION.value: + name = get_definition_name(child) + if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( + child + ): + definition_nodes.append((name, child)) + function_names.add(get_definition_name(child)) + elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: + subchild = child.children[0] + name = get_definition_name(subchild) + if not (name in variable_names or name in function_names or name in class_names): + definition_nodes.append((name, subchild)) + variable_names.add(name) + + if entrypoint: + name2deps = get_deps(definition_nodes) + reacheable = get_function_dependency(entrypoint, name2deps) + + sanitized_output = b"" + + for node in import_nodes: + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + + for pair in definition_nodes: + name, node = pair + if entrypoint and name not in reacheable: + continue + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + return sanitized_output[:-1].decode("utf8") diff --git a/metagpt/ext/aflow/scripts/operator.py b/metagpt/ext/aflow/scripts/operator.py index bfd875b26..0d71f0a2f 100644 --- a/metagpt/ext/aflow/scripts/operator.py +++ b/metagpt/ext/aflow/scripts/operator.py @@ -207,7 +207,7 @@ class Programmer(Operator): if status == "Success": return {"code": code, "output": output} else: - print(f"Execution error on attempt {i + 1}, error message: {output}") + logger.info(f"Execution error on attempt {i + 1}, error message: {output}") feedback = ( f"\nThe result of the error from the code you wrote in the previous round:\n" f"Code: {code}\n\nStatus: {status}, {output}" @@ -336,7 +336,7 @@ class MdEnsemble(Operator): return shuffled_solutions, answer_mapping async def __call__(self, solutions: List[str], problem: str, mode: str = None): - print(f"solution count: {len(solutions)}") + logger.info(f"solution count: {len(solutions)}") all_responses = [] for _ in range(self.vote_count):