mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 17:26:22 +02:00
Update
This commit is contained in:
parent
8c7cde533b
commit
344d87d2e6
11 changed files with 191 additions and 186 deletions
|
|
@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
|
|||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class DROPBenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
|
|||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class GSM8KBenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
|
|||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class HotpotQABenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.actions.code_sanitize import sanitize
|
||||
from metagpt.logs import logger
|
||||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.ext.aflow.scripts.utils import sanitize
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class HumanEvalBenchmark(BaseBenchmark):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fi
|
|||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class MATHBenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.actions.code_sanitize import sanitize
|
||||
from metagpt.logs import logger
|
||||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class MBPPBenchmark(BaseBenchmark):
|
||||
|
|
|
|||
|
|
@ -9,14 +9,8 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import ast
|
||||
import traceback
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
import tree_sitter_python
|
||||
from tree_sitter import Language, Node, Parser
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_random_indices(n, n_samples, test=False):
|
||||
|
|
@ -82,172 +76,3 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path):
|
|||
# 将数据写回到log.json文件
|
||||
with open(log_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
CLASS = "class_definition"
|
||||
FUNCTION = "function_definition"
|
||||
IMPORT = ["import_statement", "import_from_statement"]
|
||||
IDENTIFIER = "identifier"
|
||||
ATTRIBUTE = "attribute"
|
||||
RETURN = "return_statement"
|
||||
EXPRESSION = "expression_statement"
|
||||
ASSIGNMENT = "assignment"
|
||||
|
||||
|
||||
def traverse_tree(node: Node) -> Generator[Node, None, None]:
|
||||
"""
|
||||
Traverse the tree structure starting from the given node.
|
||||
|
||||
:param node: The root node to start the traversal from.
|
||||
:return: A generator object that yields nodes in the tree.
|
||||
"""
|
||||
cursor = node.walk()
|
||||
depth = 0
|
||||
|
||||
visited_children = False
|
||||
while True:
|
||||
if not visited_children:
|
||||
yield cursor.node
|
||||
if not cursor.goto_first_child():
|
||||
depth += 1
|
||||
visited_children = True
|
||||
elif cursor.goto_next_sibling():
|
||||
visited_children = False
|
||||
elif not cursor.goto_parent() or depth == 0:
|
||||
break
|
||||
else:
|
||||
depth -= 1
|
||||
|
||||
|
||||
def syntax_check(code, verbose=False):
|
||||
try:
|
||||
ast.parse(code)
|
||||
return True
|
||||
except (SyntaxError, MemoryError):
|
||||
if verbose:
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def code_extract(text: str) -> str:
|
||||
lines = text.split("\n")
|
||||
longest_line_pair = (0, 0)
|
||||
longest_so_far = 0
|
||||
|
||||
for i in range(len(lines)):
|
||||
for j in range(i + 1, len(lines)):
|
||||
current_lines = "\n".join(lines[i : j + 1])
|
||||
if syntax_check(current_lines):
|
||||
current_length = sum(1 for line in lines[i : j + 1] if line.strip())
|
||||
if current_length > longest_so_far:
|
||||
longest_so_far = current_length
|
||||
longest_line_pair = (i, j)
|
||||
|
||||
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
|
||||
|
||||
|
||||
def get_definition_name(node: Node) -> str:
|
||||
for child in node.children:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
return child.text.decode("utf8")
|
||||
|
||||
|
||||
def has_return_statement(node: Node) -> bool:
|
||||
traverse_nodes = traverse_tree(node)
|
||||
for node in traverse_nodes:
|
||||
if node.type == NodeType.RETURN.value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
|
||||
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
|
||||
for child in node.children:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
deps.add(child.text.decode("utf8"))
|
||||
else:
|
||||
dfs_get_deps(child, deps)
|
||||
|
||||
name2deps = {}
|
||||
for name, node in nodes:
|
||||
deps = set()
|
||||
dfs_get_deps(node, deps)
|
||||
name2deps[name] = deps
|
||||
return name2deps
|
||||
|
||||
|
||||
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
|
||||
queue = [entrypoint]
|
||||
visited = {entrypoint}
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current not in call_graph:
|
||||
continue
|
||||
for neighbour in call_graph[current]:
|
||||
if neighbour not in visited:
|
||||
visited.add(neighbour)
|
||||
queue.append(neighbour)
|
||||
return visited
|
||||
|
||||
|
||||
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
||||
"""
|
||||
Sanitize and extract relevant parts of the given Python code.
|
||||
This function parses the input code, extracts import statements, class and function definitions,
|
||||
and variable assignments. If an entrypoint is provided, it only includes definitions that are
|
||||
reachable from the entrypoint in the call graph.
|
||||
|
||||
:param code: The input Python code as a string.
|
||||
:param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
|
||||
:return: A sanitized version of the input code, containing only relevant parts.
|
||||
"""
|
||||
code = code_extract(code)
|
||||
code_bytes = bytes(code, "utf8")
|
||||
parser = Parser(Language(tree_sitter_python.language()))
|
||||
tree = parser.parse(code_bytes)
|
||||
class_names = set()
|
||||
function_names = set()
|
||||
variable_names = set()
|
||||
|
||||
root_node = tree.root_node
|
||||
import_nodes = []
|
||||
definition_nodes = []
|
||||
|
||||
for child in root_node.children:
|
||||
if child.type in NodeType.IMPORT.value:
|
||||
import_nodes.append(child)
|
||||
elif child.type == NodeType.CLASS.value:
|
||||
name = get_definition_name(child)
|
||||
if not (name in class_names or name in variable_names or name in function_names):
|
||||
definition_nodes.append((name, child))
|
||||
class_names.add(name)
|
||||
elif child.type == NodeType.FUNCTION.value:
|
||||
name = get_definition_name(child)
|
||||
if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
|
||||
child
|
||||
):
|
||||
definition_nodes.append((name, child))
|
||||
function_names.add(get_definition_name(child))
|
||||
elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
|
||||
subchild = child.children[0]
|
||||
name = get_definition_name(subchild)
|
||||
if not (name in variable_names or name in function_names or name in class_names):
|
||||
definition_nodes.append((name, subchild))
|
||||
variable_names.add(name)
|
||||
|
||||
if entrypoint:
|
||||
name2deps = get_deps(definition_nodes)
|
||||
reacheable = get_function_dependency(entrypoint, name2deps)
|
||||
|
||||
sanitized_output = b""
|
||||
|
||||
for node in import_nodes:
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
|
||||
for pair in definition_nodes:
|
||||
name, node = pair
|
||||
if entrypoint and name not in reacheable:
|
||||
continue
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
return sanitized_output[:-1].decode("utf8")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from tqdm import tqdm
|
|||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def download_file(url: str, filename: str) -> None:
|
||||
"""Download a file from the given URL and show progress."""
|
||||
response = requests.get(url, stream=True)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from typing import Dict, List, Tuple
|
|||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.ext.aflow.scripts.operator_an import (
|
||||
AnswerGenerateOp,
|
||||
CodeGenerateOp,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 7/2/2024 17:36 PM
|
||||
# @Author : didi
|
||||
# @Desc : utils for experiment
|
||||
"""
|
||||
@Time : 2024/7/24 16:37
|
||||
@Author : didi
|
||||
@File : utils.py
|
||||
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from typing import Any, List, Tuple
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
import tree_sitter_python
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
|
||||
def extract_task_id(task_id: str) -> int:
|
||||
|
|
@ -161,3 +168,172 @@ def test_check():
|
|||
test_check()
|
||||
"""
|
||||
return tester_function
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
CLASS = "class_definition"
|
||||
FUNCTION = "function_definition"
|
||||
IMPORT = ["import_statement", "import_from_statement"]
|
||||
IDENTIFIER = "identifier"
|
||||
ATTRIBUTE = "attribute"
|
||||
RETURN = "return_statement"
|
||||
EXPRESSION = "expression_statement"
|
||||
ASSIGNMENT = "assignment"
|
||||
|
||||
|
||||
def traverse_tree(node: Node) -> Generator[Node, None, None]:
|
||||
"""
|
||||
Traverse the tree structure starting from the given node.
|
||||
|
||||
:param node: The root node to start the traversal from.
|
||||
:return: A generator object that yields nodes in the tree.
|
||||
"""
|
||||
cursor = node.walk()
|
||||
depth = 0
|
||||
|
||||
visited_children = False
|
||||
while True:
|
||||
if not visited_children:
|
||||
yield cursor.node
|
||||
if not cursor.goto_first_child():
|
||||
depth += 1
|
||||
visited_children = True
|
||||
elif cursor.goto_next_sibling():
|
||||
visited_children = False
|
||||
elif not cursor.goto_parent() or depth == 0:
|
||||
break
|
||||
else:
|
||||
depth -= 1
|
||||
|
||||
|
||||
def syntax_check(code, verbose=False):
|
||||
try:
|
||||
ast.parse(code)
|
||||
return True
|
||||
except (SyntaxError, MemoryError):
|
||||
if verbose:
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def code_extract(text: str) -> str:
|
||||
lines = text.split("\n")
|
||||
longest_line_pair = (0, 0)
|
||||
longest_so_far = 0
|
||||
|
||||
for i in range(len(lines)):
|
||||
for j in range(i + 1, len(lines)):
|
||||
current_lines = "\n".join(lines[i : j + 1])
|
||||
if syntax_check(current_lines):
|
||||
current_length = sum(1 for line in lines[i : j + 1] if line.strip())
|
||||
if current_length > longest_so_far:
|
||||
longest_so_far = current_length
|
||||
longest_line_pair = (i, j)
|
||||
|
||||
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
|
||||
|
||||
|
||||
def get_definition_name(node: Node) -> str:
|
||||
for child in node.children:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
return child.text.decode("utf8")
|
||||
|
||||
|
||||
def has_return_statement(node: Node) -> bool:
|
||||
traverse_nodes = traverse_tree(node)
|
||||
for node in traverse_nodes:
|
||||
if node.type == NodeType.RETURN.value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
|
||||
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
|
||||
for child in node.children:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
deps.add(child.text.decode("utf8"))
|
||||
else:
|
||||
dfs_get_deps(child, deps)
|
||||
|
||||
name2deps = {}
|
||||
for name, node in nodes:
|
||||
deps = set()
|
||||
dfs_get_deps(node, deps)
|
||||
name2deps[name] = deps
|
||||
return name2deps
|
||||
|
||||
|
||||
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
|
||||
queue = [entrypoint]
|
||||
visited = {entrypoint}
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current not in call_graph:
|
||||
continue
|
||||
for neighbour in call_graph[current]:
|
||||
if neighbour not in visited:
|
||||
visited.add(neighbour)
|
||||
queue.append(neighbour)
|
||||
return visited
|
||||
|
||||
|
||||
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
||||
"""
|
||||
Sanitize and extract relevant parts of the given Python code.
|
||||
This function parses the input code, extracts import statements, class and function definitions,
|
||||
and variable assignments. If an entrypoint is provided, it only includes definitions that are
|
||||
reachable from the entrypoint in the call graph.
|
||||
|
||||
:param code: The input Python code as a string.
|
||||
:param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
|
||||
:return: A sanitized version of the input code, containing only relevant parts.
|
||||
"""
|
||||
code = code_extract(code)
|
||||
code_bytes = bytes(code, "utf8")
|
||||
parser = Parser(Language(tree_sitter_python.language()))
|
||||
tree = parser.parse(code_bytes)
|
||||
class_names = set()
|
||||
function_names = set()
|
||||
variable_names = set()
|
||||
|
||||
root_node = tree.root_node
|
||||
import_nodes = []
|
||||
definition_nodes = []
|
||||
|
||||
for child in root_node.children:
|
||||
if child.type in NodeType.IMPORT.value:
|
||||
import_nodes.append(child)
|
||||
elif child.type == NodeType.CLASS.value:
|
||||
name = get_definition_name(child)
|
||||
if not (name in class_names or name in variable_names or name in function_names):
|
||||
definition_nodes.append((name, child))
|
||||
class_names.add(name)
|
||||
elif child.type == NodeType.FUNCTION.value:
|
||||
name = get_definition_name(child)
|
||||
if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
|
||||
child
|
||||
):
|
||||
definition_nodes.append((name, child))
|
||||
function_names.add(get_definition_name(child))
|
||||
elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
|
||||
subchild = child.children[0]
|
||||
name = get_definition_name(subchild)
|
||||
if not (name in variable_names or name in function_names or name in class_names):
|
||||
definition_nodes.append((name, subchild))
|
||||
variable_names.add(name)
|
||||
|
||||
if entrypoint:
|
||||
name2deps = get_deps(definition_nodes)
|
||||
reacheable = get_function_dependency(entrypoint, name2deps)
|
||||
|
||||
sanitized_output = b""
|
||||
|
||||
for node in import_nodes:
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
|
||||
for pair in definition_nodes:
|
||||
name, node = pair
|
||||
if entrypoint and name not in reacheable:
|
||||
continue
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
return sanitized_output[:-1].decode("utf8")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue