This commit is contained in:
didi 2024-10-22 18:14:15 +08:00
parent 8c7cde533b
commit 344d87d2e6
11 changed files with 191 additions and 186 deletions

View file

@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.logs import logger
class DROPBenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
    """Initialize the DROP benchmark with its name, dataset path, and log path."""
    super().__init__(name, file_path, log_path)

View file

@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.logs import logger
class GSM8KBenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
    """Initialize the GSM8K benchmark with its name, dataset path, and log path."""
    super().__init__(name, file_path, log_path)

View file

@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.logs import logger
class HotpotQABenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
    """Initialize the HotpotQA benchmark with its name, dataset path, and log path."""
    super().__init__(name, file_path, log_path)

View file

@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from metagpt.actions.code_sanitize import sanitize
from metagpt.logs import logger
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.ext.aflow.scripts.utils import sanitize
from metagpt.logs import logger
class HumanEvalBenchmark(BaseBenchmark):

View file

@ -11,6 +11,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.logs import logger
class MATHBenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
    """Initialize the MATH benchmark with its name, dataset path, and log path."""
    super().__init__(name, file_path, log_path)

View file

@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from metagpt.actions.code_sanitize import sanitize
from metagpt.logs import logger
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.logs import logger
class MBPPBenchmark(BaseBenchmark):

View file

@ -9,14 +9,8 @@
import json
import os
import ast
import traceback
import numpy as np
from enum import Enum
from typing import Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tree_sitter import Language, Node, Parser
import numpy as np
def generate_random_indices(n, n_samples, test=False):
@ -82,172 +76,3 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path):
# Write the data back to the log.json file
with open(log_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
class NodeType(Enum):
    """tree-sitter grammar node-type names used when walking a parsed Python tree."""

    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # Both flavours of import statement; checked with `in NodeType.IMPORT.value`.
    IMPORT = ["import_statement", "import_from_statement"]
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node.

    Pre-order, depth-first walk over the subtree rooted at *node*, implemented
    iteratively with a tree-sitter cursor (no recursion).

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    depth = 0
    visited_children = False
    while True:
        if not visited_children:
            # Yield on first visit, then try to descend.
            yield cursor.node
            if not cursor.goto_first_child():
                # NOTE(review): depth is incremented on leaf nodes rather than
                # on descent; together with the goto_parent/depth == 0 check it
                # terminates the walk back at the root — confirm against
                # tree_sitter TreeCursor semantics before refactoring.
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            break
        else:
            depth -= 1
def syntax_check(code, verbose=False):
    """Return True when *code* parses as valid Python source, False otherwise.

    :param code: Source text to validate.
    :param verbose: When True, print the parse traceback on failure.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError):
        if verbose:
            traceback.print_exc()
        return False
    return True
def code_extract(text: str) -> str:
    """Return the longest contiguous span of *text* that parses as valid Python.

    Scans every (start, end) line window, keeps the syntactically valid window
    containing the most non-blank lines, and returns it joined back together.

    :param text: Arbitrary text possibly containing a Python snippet.
    :return: The best-scoring valid span; falls back to the first line of
             *text* when no window parses.
    """
    lines = text.split("\n")
    longest_line_pair = (0, 0)
    longest_so_far = 0

    for i in range(len(lines)):
        # Start at j == i so single-line windows are also considered; the
        # previous range(i + 1, ...) skipped one-line programs entirely.
        for j in range(i, len(lines)):
            current_lines = "\n".join(lines[i : j + 1])
            if syntax_check(current_lines):
                current_length = sum(1 for line in lines[i : j + 1] if line.strip())
                if current_length > longest_so_far:
                    longest_so_far = current_length
                    longest_line_pair = (i, j)

    return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
def get_definition_name(node: Node) -> str:
    """Return the identifier text of a definition node (None if no identifier child)."""
    ident = NodeType.IDENTIFIER.value
    for candidate in node.children:
        if candidate.type == ident:
            return candidate.text.decode("utf8")
    return None
def has_return_statement(node: Node) -> bool:
    """Whether the subtree rooted at *node* contains any return statement."""
    return any(
        descendant.type == NodeType.RETURN.value for descendant in traverse_tree(node)
    )
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifiers its subtree references."""
    ident = NodeType.IDENTIFIER.value

    def collect(tree_node: Node, acc: Set[str]) -> None:
        # Leaf identifiers are dependencies; recurse into everything else.
        for child in tree_node.children:
            if child.type == ident:
                acc.add(child.text.decode("utf8"))
            else:
                collect(child, acc)

    name2deps: Dict[str, Set[str]] = {}
    for name, node in nodes:
        referenced: Set[str] = set()
        collect(node, referenced)
        name2deps[name] = referenced
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Breadth-first search of *call_graph* from *entrypoint*.

    :param entrypoint: Name to start from (always included in the result).
    :param call_graph: Mapping of definition name -> set of names it references.
                       (Annotation fixed: values are sets, as produced by get_deps.)
    :return: Every name reachable from the entrypoint, including the entrypoint.
    """
    from collections import deque  # local import keeps the block self-contained

    queue = deque([entrypoint])
    visited = {entrypoint}
    while queue:
        current = queue.popleft()  # O(1) vs list.pop(0)'s O(n)
        # Names absent from the graph (e.g. builtins) are treated as leaves.
        if current not in call_graph:
            continue
        for neighbour in call_graph[current]:
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize and extract relevant parts of the given Python code.

    This function parses the input code, extracts import statements, class and function definitions,
    and variable assignments. If an entrypoint is provided, it only includes definitions that are
    reachable from the entrypoint in the call graph.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
    :return: A sanitized version of the input code, containing only relevant parts.
    """
    # First narrow the text down to its longest syntactically valid span.
    code = code_extract(code)
    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    tree = parser.parse(code_bytes)
    class_names = set()
    function_names = set()
    variable_names = set()

    root_node = tree.root_node
    import_nodes = []
    definition_nodes = []  # (name, node) pairs; first definition of a name wins

    for child in root_node.children:
        if child.type in NodeType.IMPORT.value:
            import_nodes.append(child)
        elif child.type == NodeType.CLASS.value:
            name = get_definition_name(child)
            # Keep only the first top-level definition of any given name.
            if not (name in class_names or name in variable_names or name in function_names):
                definition_nodes.append((name, child))
                class_names.add(name)
        elif child.type == NodeType.FUNCTION.value:
            name = get_definition_name(child)
            # Functions whose bodies contain no return statement are dropped.
            if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
                child
            ):
                definition_nodes.append((name, child))
                function_names.add(get_definition_name(child))
        elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
            # Top-level assignment wrapped in an expression statement.
            subchild = child.children[0]
            name = get_definition_name(subchild)
            if not (name in variable_names or name in function_names or name in class_names):
                definition_nodes.append((name, subchild))
                variable_names.add(name)

    if entrypoint:
        name2deps = get_deps(definition_nodes)
        reacheable = get_function_dependency(entrypoint, name2deps)

    # Re-emit imports first, then the (reachable) definitions, sliced as raw
    # bytes from the original source so formatting is preserved verbatim.
    sanitized_output = b""
    for node in import_nodes:
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
    for pair in definition_nodes:
        name, node = pair
        if entrypoint and name not in reacheable:
            continue
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
    # Drop the trailing newline added after the last emitted node.
    return sanitized_output[:-1].decode("utf8")

View file

@ -12,6 +12,7 @@ from tqdm import tqdm
from metagpt.logs import logger
def download_file(url: str, filename: str) -> None:
"""Download a file from the given URL and show progress."""
response = requests.get(url, stream=True)

View file

@ -13,7 +13,6 @@ from typing import Dict, List, Tuple
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.ext.aflow.scripts.operator_an import (
AnswerGenerateOp,
CodeGenerateOp,

View file

@ -1,12 +1,19 @@
# -*- coding: utf-8 -*-
# @Date : 7/2/2024 17:36 PM
# @Author : didi
# @Desc : utils for experiment
"""
@Time : 2024/7/24 16:37
@Author : didi
@File : utils.py
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
"""
import ast
import json
import re
from typing import Any, List, Tuple
import traceback
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tree_sitter import Language, Node, Parser
def extract_task_id(task_id: str) -> int:
@ -161,3 +168,172 @@ def test_check():
test_check()
"""
return tester_function
class NodeType(Enum):
    """tree-sitter grammar node-type names used when walking a parsed Python tree."""

    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # Both flavours of import statement; checked with `in NodeType.IMPORT.value`.
    IMPORT = ["import_statement", "import_from_statement"]
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node.

    Pre-order, depth-first walk over the subtree rooted at *node*, implemented
    iteratively with a tree-sitter cursor (no recursion).

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    depth = 0
    visited_children = False
    while True:
        if not visited_children:
            # Yield on first visit, then try to descend.
            yield cursor.node
            if not cursor.goto_first_child():
                # NOTE(review): depth is incremented on leaf nodes rather than
                # on descent; together with the goto_parent/depth == 0 check it
                # terminates the walk back at the root — confirm against
                # tree_sitter TreeCursor semantics before refactoring.
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            break
        else:
            depth -= 1
def syntax_check(code, verbose=False):
    """Return True when *code* parses as valid Python source, False otherwise.

    :param code: Source text to validate.
    :param verbose: When True, print the parse traceback on failure.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError):
        if verbose:
            traceback.print_exc()
        return False
    return True
def code_extract(text: str) -> str:
    """Return the longest contiguous span of *text* that parses as valid Python.

    Scans every (start, end) line window, keeps the syntactically valid window
    containing the most non-blank lines, and returns it joined back together.

    :param text: Arbitrary text possibly containing a Python snippet.
    :return: The best-scoring valid span; falls back to the first line of
             *text* when no window parses.
    """
    lines = text.split("\n")
    longest_line_pair = (0, 0)
    longest_so_far = 0

    for i in range(len(lines)):
        # Start at j == i so single-line windows are also considered; the
        # previous range(i + 1, ...) skipped one-line programs entirely.
        for j in range(i, len(lines)):
            current_lines = "\n".join(lines[i : j + 1])
            if syntax_check(current_lines):
                current_length = sum(1 for line in lines[i : j + 1] if line.strip())
                if current_length > longest_so_far:
                    longest_so_far = current_length
                    longest_line_pair = (i, j)

    return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
def get_definition_name(node: Node) -> str:
    """Return the identifier text of a definition node (None if no identifier child)."""
    ident = NodeType.IDENTIFIER.value
    for candidate in node.children:
        if candidate.type == ident:
            return candidate.text.decode("utf8")
    return None
def has_return_statement(node: Node) -> bool:
    """Whether the subtree rooted at *node* contains any return statement."""
    return any(
        descendant.type == NodeType.RETURN.value for descendant in traverse_tree(node)
    )
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifiers its subtree references."""
    ident = NodeType.IDENTIFIER.value

    def collect(tree_node: Node, acc: Set[str]) -> None:
        # Leaf identifiers are dependencies; recurse into everything else.
        for child in tree_node.children:
            if child.type == ident:
                acc.add(child.text.decode("utf8"))
            else:
                collect(child, acc)

    name2deps: Dict[str, Set[str]] = {}
    for name, node in nodes:
        referenced: Set[str] = set()
        collect(node, referenced)
        name2deps[name] = referenced
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Breadth-first search of *call_graph* from *entrypoint*.

    :param entrypoint: Name to start from (always included in the result).
    :param call_graph: Mapping of definition name -> set of names it references.
                       (Annotation fixed: values are sets, as produced by get_deps.)
    :return: Every name reachable from the entrypoint, including the entrypoint.
    """
    from collections import deque  # local import keeps the block self-contained

    queue = deque([entrypoint])
    visited = {entrypoint}
    while queue:
        current = queue.popleft()  # O(1) vs list.pop(0)'s O(n)
        # Names absent from the graph (e.g. builtins) are treated as leaves.
        if current not in call_graph:
            continue
        for neighbour in call_graph[current]:
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize and extract relevant parts of the given Python code.

    This function parses the input code, extracts import statements, class and function definitions,
    and variable assignments. If an entrypoint is provided, it only includes definitions that are
    reachable from the entrypoint in the call graph.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
    :return: A sanitized version of the input code, containing only relevant parts.
    """
    # First narrow the text down to its longest syntactically valid span.
    code = code_extract(code)
    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    tree = parser.parse(code_bytes)
    class_names = set()
    function_names = set()
    variable_names = set()

    root_node = tree.root_node
    import_nodes = []
    definition_nodes = []  # (name, node) pairs; first definition of a name wins

    for child in root_node.children:
        if child.type in NodeType.IMPORT.value:
            import_nodes.append(child)
        elif child.type == NodeType.CLASS.value:
            name = get_definition_name(child)
            # Keep only the first top-level definition of any given name.
            if not (name in class_names or name in variable_names or name in function_names):
                definition_nodes.append((name, child))
                class_names.add(name)
        elif child.type == NodeType.FUNCTION.value:
            name = get_definition_name(child)
            # Functions whose bodies contain no return statement are dropped.
            if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
                child
            ):
                definition_nodes.append((name, child))
                function_names.add(get_definition_name(child))
        elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
            # Top-level assignment wrapped in an expression statement.
            subchild = child.children[0]
            name = get_definition_name(subchild)
            if not (name in variable_names or name in function_names or name in class_names):
                definition_nodes.append((name, subchild))
                variable_names.add(name)

    if entrypoint:
        name2deps = get_deps(definition_nodes)
        reacheable = get_function_dependency(entrypoint, name2deps)

    # Re-emit imports first, then the (reachable) definitions, sliced as raw
    # bytes from the original source so formatting is preserved verbatim.
    sanitized_output = b""
    for node in import_nodes:
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
    for pair in definition_nodes:
        name, node = pair
        if entrypoint and name not in reacheable:
            continue
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
    # Drop the trailing newline added after the last emitted node.
    return sanitized_output[:-1].decode("utf8")