Update Latest Review

This commit is contained in:
didi 2024-10-28 21:23:46 +08:00
parent 92e520ded2
commit f0a3a3f739
15 changed files with 276 additions and 348 deletions

View file

@ -3,6 +3,7 @@ import json
import os
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Tuple
import aiofiles
@ -10,6 +11,7 @@ import pandas as pd
from tqdm.asyncio import tqdm_asyncio
from metagpt.logs import logger
from metagpt.utils.common import write_json_file
class BaseBenchmark(ABC):
@ -18,6 +20,9 @@ class BaseBenchmark(ABC):
self.file_path = file_path
self.log_path = log_path
PASS = "PASS"
FAIL = "FAIL"
async def load_data(self, specific_indices: List[int] = None) -> List[dict]:
data = []
async with aiofiles.open(self.file_path, mode="r", encoding="utf-8") as file:
@ -55,9 +60,9 @@ class BaseBenchmark(ABC):
"extracted_output": extracted_output,
"extract_answer_code": extract_answer_code,
}
log_file = os.path.join(self.log_path, "log.json")
if os.path.exists(log_file):
with open(log_file, "r", encoding="utf-8") as f:
log_file = Path(self.log_path) / "log.json"
if log_file.exists():
with log_file.open("r", encoding="utf-8") as f:
try:
data = json.load(f)
except json.JSONDecodeError:
@ -65,8 +70,7 @@ class BaseBenchmark(ABC):
else:
data = []
data.append(log_data)
with open(log_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
write_json_file(log_file, data, encoding="utf-8", indent=4)
@abstractmethod
async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]:

View file

@ -6,17 +6,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.ext.aflow.scripts.utils import sanitize
from metagpt.logs import logger
from metagpt.utils.sanitize import sanitize
class HumanEvalBenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
super().__init__(name, file_path, log_path)
PASS = "PASS"
FAIL = "FAIL"
class TimeoutError(Exception):
pass

View file

@ -5,17 +5,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
from metagpt.ext.aflow.scripts.utils import sanitize
from metagpt.logs import logger
from metagpt.utils.sanitize import sanitize
class MBPPBenchmark(BaseBenchmark):
def __init__(self, name: str, file_path: str, log_path: str):
super().__init__(name, file_path, log_path)
PASS = "PASS"
FAIL = "FAIL"
class TimeoutError(Exception):
pass

View file

@ -4,7 +4,6 @@
@Time : 2024/7/24 16:37
@Author : didi
@File : utils.py
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
"""
import json
@ -12,6 +11,8 @@ import os
import numpy as np
from metagpt.utils.common import write_json_file
def generate_random_indices(n, n_samples, test=False):
"""
@ -41,13 +42,6 @@ def split_data_set(file_path, samples, test=False):
return data
# save data into a jsonl file
def save_data(data, file_path):
    """Write every item of ``data`` to ``file_path`` as one JSON document per line.

    :param data: Iterable of JSON-serializable objects.
    :param file_path: Destination path; an existing file is overwritten.
    """
    with open(file_path, "w") as out:
        out.writelines(json.dumps(record) + "\n" for record in data)
def log_mismatch(problem, expected_output, prediction, predicted_number, path):
log_data = {
"question": problem,
@ -74,5 +68,4 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path):
data.append(log_data)
# Write the data back to the log.json file
with open(log_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
write_json_file(log_file, data, encoding="utf-8", indent=4)

View file

@ -2,62 +2,11 @@
@Time : 2024/7/24 16:37
@Author : didi
@File : utils.py
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
"""
import ast
import json
import re
import traceback
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tree_sitter import Language, Node, Parser
def extract_task_id(task_id: str) -> int:
    """Return the integer formed by the digits after the first ``/<digits>`` match in ``task_id``.

    Falls back to 0 when ``task_id`` contains no slash followed by digits.
    """
    found = re.search(r"/(\d+)", task_id)
    if found is None:
        return 0
    return int(found.group(1))
def get_hotpotqa(path: str):
    """Load a JSONL file into a dict keyed by each record's ``"_id"`` field.

    The file is decoded as UTF-8 (rather than the locale's default encoding) so
    non-ASCII text is read the same way on every platform, and blank lines are
    skipped instead of aborting the load with a ``JSONDecodeError``.

    :param path: Path to the JSONL file.
    :return: Mapping of ``_id`` to the parsed record; if an ``_id`` repeats, the
        last record with that id wins.
    :raises KeyError: If a record has no ``"_id"`` field.
    """
    records = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate empty / whitespace-only lines (e.g. a trailing newline).
                continue
            record = json.loads(line)
            records[record["_id"]] = record
    return records
def sort_json_by_key(input_file: str, output_file: str, key: str = "task_id"):
    """
    Read a JSONL file, sort the entries by the number embedded in ``key``, and write to a new JSONL file.

    Both files are handled as UTF-8 and blank input lines are ignored. The sort is
    stable, so entries whose extracted number is equal keep their input order.

    :param input_file: Path to the input JSONL file
    :param output_file: Path to the output JSONL file
    :param key: Name of the field whose value is passed to ``extract_task_id`` to obtain the sort key
    """
    # Read and parse the JSONL file, skipping empty lines that json.loads would reject
    with open(input_file, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Sort the data based on the numeric part of the selected field
    sorted_data = sorted(data, key=lambda x: extract_task_id(x[key]))

    # Write the sorted data to a new JSONL file
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in sorted_data)
def parse_python_literal(s):
    """Evaluate ``s`` as a Python literal, returning ``s`` unchanged when that is not possible.

    ``ast.literal_eval`` is documented to raise ``ValueError``, ``TypeError``,
    ``SyntaxError``, ``MemoryError`` or ``RecursionError`` on malformed input; all of
    them are treated as "not a literal" so the fallback is honoured in every case
    (e.g. ``"{[1]: 2}"`` raises ``TypeError`` because a list cannot be a dict key).

    :param s: Candidate literal, normally a string.
    :return: The evaluated Python object, or ``s`` itself if evaluation fails.
    """
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return s
from typing import Any, List, Tuple
def extract_test_cases_from_jsonl(entry_point: str, dataset: str = "HumanEval"):
@ -168,172 +117,3 @@ def test_check():
test_check()
"""
return tester_function
class NodeType(Enum):
    """Node ``type`` strings that the sanitizer compares syntax-tree nodes against.

    Every member's value is a single type name, except ``IMPORT`` whose value is a
    list so that callers can test membership with ``node.type in NodeType.IMPORT.value``.
    """

    # Top-level definitions
    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # Both spellings of an import: ``import x`` and ``from x import y``
    IMPORT = ["import_statement", "import_from_statement"]
    # Names and dotted access
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    # Statements
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Lazily walk the tree below ``node`` with a cursor, yielding each node before its children.

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    walker = node.walk()
    level = 0
    children_done = False

    while True:
        if children_done:
            # The current node's subtree is exhausted: move sideways, else climb.
            if walker.goto_next_sibling():
                children_done = False
                continue
            # goto_parent() is always attempted before ``level`` is inspected.
            if not walker.goto_parent() or level == 0:
                return
            level -= 1
            continue

        yield walker.node
        if not walker.goto_first_child():
            # No child to descend into; start looking for siblings/parents.
            level += 1
            children_done = True
def syntax_check(code, verbose=False):
    """Return whether ``code`` can be parsed by ``ast.parse``.

    Any parse failure yields ``False`` rather than an exception. Besides
    ``SyntaxError`` and ``MemoryError`` this includes ``ValueError`` (raised for
    source containing null bytes on Python versions before 3.12) and
    ``RecursionError`` (very deeply nested source), so arbitrary text can be
    probed safely.

    :param code: Source text to check.
    :param verbose: When true, print the traceback of the parse failure.
    :return: ``True`` if the text parses, ``False`` otherwise.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        if verbose:
            traceback.print_exc()
        return False
    return True
def code_extract(text: str) -> str:
    """Return the contiguous run of lines in ``text`` that parses as Python and has the most non-blank lines.

    Every window ``lines[start : end + 1]`` is a candidate, including single-line
    windows, so a lone valid statement surrounded by unparsable text is still
    found. Among valid windows with the same number of non-blank lines the one
    encountered first (smallest ``start``, then smallest ``end``) is kept. If no
    window with at least one non-blank line parses, the first line is returned.

    :param text: Arbitrary text that may embed Python code.
    :return: The selected lines joined with ``"\\n"``.
    """
    lines = text.split("\n")

    # non_blank_before[k] == number of non-blank lines in lines[:k]; lets each
    # window's size be computed in O(1) instead of re-scanning the window.
    non_blank_before = [0]
    for line in lines:
        non_blank_before.append(non_blank_before[-1] + bool(line.strip()))

    best_start, best_end = 0, 0
    best_count = 0
    for start in range(len(lines)):
        for end in range(start, len(lines)):
            count = non_blank_before[end + 1] - non_blank_before[start]
            if count <= best_count:
                # Cannot beat the current best, so parsing it would be wasted work.
                continue
            if syntax_check("\n".join(lines[start : end + 1])):
                best_count = count
                best_start, best_end = start, end

    return "\n".join(lines[best_start : best_end + 1])
def get_definition_name(node: Node) -> str:
    """Return the decoded text of the first direct child of ``node`` that is an identifier.

    Only direct children are inspected; when none of them is an identifier the
    result is ``None``.
    """
    identifier_type = NodeType.IDENTIFIER.value
    return next(
        (child.text.decode("utf8") for child in node.children if child.type == identifier_type),
        None,
    )
def has_return_statement(node: Node) -> bool:
    """Report whether any node yielded by ``traverse_tree(node)`` is a return statement.

    The traversal stops at the first match.
    """
    return any(candidate.type == NodeType.RETURN.value for candidate in traverse_tree(node))
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifier texts found beneath its node.

    The subtree of every node is searched through ``children``; a child that is an
    identifier contributes its decoded text and is not descended into, any other
    child is searched in turn.

    :param nodes: ``(name, node)`` pairs, one per definition.
    :return: ``{name: identifiers}``; a later pair with a repeated name replaces the earlier entry.
    """
    identifier_type = NodeType.IDENTIFIER.value
    name2deps = {}
    for name, root in nodes:
        found = set()
        # Explicit stack instead of recursion; the result is a set, so visit order is irrelevant.
        pending = [root]
        while pending:
            current = pending.pop()
            for child in current.children:
                if child.type == identifier_type:
                    found.add(child.text.decode("utf8"))
                else:
                    pending.append(child)
        name2deps[name] = found
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from ``entrypoint`` by following ``call_graph`` edges.

    :param entrypoint: Name to start from; it is always part of the result, even
        when it has no entry in ``call_graph``.
    :param call_graph: Mapping from a name to the collection of names it refers to.
        Names without an entry are treated as having no outgoing edges.
    :return: The set of reachable names, including ``entrypoint``.
    """
    visited = {entrypoint}
    # Only the reachable *set* is returned, so visit order does not matter; popping
    # from the end of the list is O(1), unlike ``pop(0)`` which shifts every element.
    pending = [entrypoint]
    while pending:
        current = pending.pop()
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                pending.append(neighbour)
    return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Reduce ``code`` to its imports and top-level definitions.

    The text is first narrowed with ``code_extract`` and then parsed. From the
    top-level children of the parse tree the function keeps, in this order:
    all import statements, followed by class definitions, function definitions that
    contain a return statement, and assignments, in their source order. For
    each name only the first class/function/assignment using it is kept. When
    ``entrypoint`` is given, definitions whose name is not reachable from it
    (per ``get_deps`` / ``get_function_dependency``) are left out; imports are always kept.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
    :return: The kept source fragments joined by newlines.
    """
    code = code_extract(code)
    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    root_node = parser.parse(code_bytes).root_node

    # Classes, functions and variables share one namespace for de-duplication.
    seen_names = set()
    import_nodes = []
    definition_nodes = []

    for child in root_node.children:
        kind = child.type
        if kind in NodeType.IMPORT.value:
            import_nodes.append(child)
        elif kind == NodeType.CLASS.value:
            name = get_definition_name(child)
            if name not in seen_names:
                definition_nodes.append((name, child))
                seen_names.add(name)
        elif kind == NodeType.FUNCTION.value:
            name = get_definition_name(child)
            # Functions that never return a value are dropped.
            if name not in seen_names and has_return_statement(child):
                definition_nodes.append((name, child))
                seen_names.add(name)
        elif kind == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
            assignment = child.children[0]
            name = get_definition_name(assignment)
            if name not in seen_names:
                definition_nodes.append((name, assignment))
                seen_names.add(name)

    reachable = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definition_nodes))

    chunks = [code_bytes[found.start_byte : found.end_byte] for found in import_nodes]
    chunks.extend(
        code_bytes[found.start_byte : found.end_byte]
        for name, found in definition_nodes
        if not entrypoint or name in reachable
    )
    return b"\n".join(chunks).decode("utf8")