mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Update Latest Review
This commit is contained in:
parent
92e520ded2
commit
f0a3a3f739
15 changed files with 276 additions and 348 deletions
|
|
@ -3,6 +3,7 @@ import json
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import aiofiles
|
||||
|
|
@ -10,6 +11,7 @@ import pandas as pd
|
|||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import write_json_file
|
||||
|
||||
|
||||
class BaseBenchmark(ABC):
|
||||
|
|
@ -18,6 +20,9 @@ class BaseBenchmark(ABC):
|
|||
self.file_path = file_path
|
||||
self.log_path = log_path
|
||||
|
||||
PASS = "PASS"
|
||||
FAIL = "FAIL"
|
||||
|
||||
async def load_data(self, specific_indices: List[int] = None) -> List[dict]:
|
||||
data = []
|
||||
async with aiofiles.open(self.file_path, mode="r", encoding="utf-8") as file:
|
||||
|
|
@ -55,9 +60,9 @@ class BaseBenchmark(ABC):
|
|||
"extracted_output": extracted_output,
|
||||
"extract_answer_code": extract_answer_code,
|
||||
}
|
||||
log_file = os.path.join(self.log_path, "log.json")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
log_file = Path(self.log_path) / "log.json"
|
||||
if log_file.exists():
|
||||
with log_file.open("r", encoding="utf-8") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -65,8 +70,7 @@ class BaseBenchmark(ABC):
|
|||
else:
|
||||
data = []
|
||||
data.append(log_data)
|
||||
with open(log_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
write_json_file(log_file, data, encoding="utf-8", indent=4)
|
||||
|
||||
@abstractmethod
|
||||
async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]:
|
||||
|
|
|
|||
|
|
@ -6,17 +6,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.ext.aflow.scripts.utils import sanitize
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.sanitize import sanitize
|
||||
|
||||
|
||||
class HumanEvalBenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
||||
PASS = "PASS"
|
||||
FAIL = "FAIL"
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -5,17 +5,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
|
||||
from metagpt.ext.aflow.scripts.utils import sanitize
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.sanitize import sanitize
|
||||
|
||||
|
||||
class MBPPBenchmark(BaseBenchmark):
|
||||
def __init__(self, name: str, file_path: str, log_path: str):
|
||||
super().__init__(name, file_path, log_path)
|
||||
|
||||
PASS = "PASS"
|
||||
FAIL = "FAIL"
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
@Time : 2024/7/24 16:37
|
||||
@Author : didi
|
||||
@File : utils.py
|
||||
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -12,6 +11,8 @@ import os
|
|||
|
||||
import numpy as np
|
||||
|
||||
from metagpt.utils.common import write_json_file
|
||||
|
||||
|
||||
def generate_random_indices(n, n_samples, test=False):
|
||||
"""
|
||||
|
|
@ -41,13 +42,6 @@ def split_data_set(file_path, samples, test=False):
|
|||
return data
|
||||
|
||||
|
||||
def save_data(data, file_path):
    """Write ``data`` to ``file_path`` as JSON Lines.

    Every element of ``data`` is serialized with ``json.dumps`` and written on
    its own line. The file is opened in ``"w"`` mode, so any existing content
    is replaced.
    """
    serialized_rows = (json.dumps(record) + "\n" for record in data)
    with open(file_path, "w") as out:
        out.writelines(serialized_rows)
|
||||
|
||||
|
||||
def log_mismatch(problem, expected_output, prediction, predicted_number, path):
|
||||
log_data = {
|
||||
"question": problem,
|
||||
|
|
@ -74,5 +68,4 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path):
|
|||
data.append(log_data)
|
||||
|
||||
# Write the data back to the log.json file
|
||||
with open(log_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
write_json_file(log_file, data, encoding="utf-8", indent=4)
|
||||
|
|
|
|||
|
|
@ -2,62 +2,11 @@
|
|||
@Time : 2024/7/24 16:37
|
||||
@Author : didi
|
||||
@File : utils.py
|
||||
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
import tree_sitter_python
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
|
||||
def extract_task_id(task_id: str) -> int:
    """Return the integer formed by the first run of digits that follows a ``/``.

    :param task_id: Identifier to search.
    :return: The parsed number, or ``0`` when ``task_id`` contains no ``/<digits>`` part.
    """
    found = re.search(r"/(\d+)", task_id)
    if found is None:
        return 0
    return int(found.group(1))
|
||||
|
||||
|
||||
def get_hotpotqa(path: str):
    """Load a JSON Lines file into a dict keyed by each record's ``"_id"`` field.

    Blank (whitespace-only) lines are skipped so a trailing newline or an
    empty separator line does not abort the load. The file is read as UTF-8
    rather than with the platform's default encoding.

    :param path: Path to the JSONL file; each non-blank line must be a JSON object with an ``"_id"`` key.
    :return: Mapping of ``_id`` to the parsed record. When an ``_id`` repeats, the last record wins.
    :raises json.JSONDecodeError: If a non-blank line is not valid JSON.
    :raises KeyError: If a record has no ``"_id"`` key.
    """
    records = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            records[record["_id"]] = record
    return records
|
||||
|
||||
|
||||
def sort_json_by_key(input_file: str, output_file: str, key: str = "task_id"):
    """
    Read a JSONL file, sort its entries by the numeric part of ``key``, and write a new JSONL file.

    Entries are ordered by ``extract_task_id(entry[key])``. Blank lines in the
    input are ignored, and both files are handled as UTF-8 instead of the
    platform's default encoding.

    :param input_file: Path to the input JSONL file
    :param output_file: Path to the output JSONL file
    :param key: Name of the field whose value is passed to ``extract_task_id`` to obtain the sort key
    """
    # Read and parse the JSONL file, tolerating empty separator/trailing lines
    with open(input_file, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Sort the data based on the numeric part of the chosen field
    sorted_data = sorted(data, key=lambda x: extract_task_id(x[key]))

    # Write the sorted data to a new JSONL file, one object per line
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in sorted_data)
|
||||
|
||||
|
||||
def parse_python_literal(s):
    """Evaluate ``s`` as a Python literal, falling back to ``s`` itself.

    ``ast.literal_eval`` can fail with more than ``ValueError`` and
    ``SyntaxError``: an unhashable dict key or set member raises
    ``TypeError``, and pathological input can raise ``MemoryError`` or
    ``RecursionError``. All of these mean "not a usable literal", so the
    input is returned unchanged in each case.

    :param s: Text (or AST node) to evaluate.
    :return: The evaluated literal, or ``s`` unchanged when it cannot be evaluated.
    """
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return s
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
|
||||
def extract_test_cases_from_jsonl(entry_point: str, dataset: str = "HumanEval"):
|
||||
|
|
@ -168,172 +117,3 @@ def test_check():
|
|||
test_check()
|
||||
"""
|
||||
return tester_function
|
||||
|
||||
|
||||
class NodeType(Enum):
    """Node ``type`` strings that the code in this module compares against.

    Every member's value is a single type string, except ``IMPORT``, whose
    value is a list of two type strings that is checked with ``in``.
    """

    # Top-level definitions
    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # Both spellings of an import statement; tested with ``node.type in NodeType.IMPORT.value``
    IMPORT = ["import_statement", "import_from_statement"]
    # Names and dotted access
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    # Statements
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
|
||||
|
||||
|
||||
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node.

    Nodes are yielded in pre-order (a node before its children, children in
    sibling order) using the cursor returned by ``node.walk()``.

    ``depth`` counts how many levels the cursor currently sits below the
    starting node: it grows on every successful descent and shrinks on every
    successful ascent. The walk ends once the cursor is back at the starting
    node with nothing left to visit. The previous implementation bumped the
    counter on leaves instead of on descents and tested it only after moving
    to the parent, so it could stop early and skip later siblings.

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    depth = 0

    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if cursor.goto_first_child():
                depth += 1
            else:
                # Leaf: nothing below, so look sideways/upwards next.
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif depth == 0 or not cursor.goto_parent():
            # Back at the starting node (or the cursor cannot climb further).
            break
        else:
            depth -= 1
|
||||
|
||||
|
||||
def syntax_check(code, verbose=False):
    """Report whether ``code`` parses as Python source.

    Besides ``SyntaxError`` and ``MemoryError``, ``ast.parse`` can raise
    ``ValueError`` (source containing null bytes on interpreters older than
    3.12) and ``RecursionError`` (extremely nested input). All of them are
    treated as "does not parse" rather than being allowed to escape.

    :param code: Source text to check.
    :param verbose: When true, print the traceback of the parse failure.
    :return: ``True`` if ``ast.parse(code)`` succeeds, otherwise ``False``.
    """
    try:
        ast.parse(code)
        return True
    except (SyntaxError, MemoryError, ValueError, RecursionError):
        if verbose:
            traceback.print_exc()
        return False


def code_extract(text: str) -> str:
    """Return the run of consecutive lines of ``text`` that parses as Python and has the most non-blank lines.

    Every contiguous line range is tried, including ranges of a single line.
    The previous version only tried ranges of two or more lines, so a
    one-line snippet surrounded by prose was never found. Ties keep the
    earliest range found. When no range with a non-blank line parses, the
    first line of ``text`` is returned unchanged.

    Note: every candidate range is parsed, so the cost grows quickly with the
    number of lines in ``text``.

    :param text: Raw text that may mix prose with Python code.
    :return: The selected lines joined with newlines.
    """
    lines = text.split("\n")
    best_start, best_end = 0, 0
    best_count = 0

    for i in range(len(lines)):
        # Start at j == i so that a single valid line is a candidate too.
        for j in range(i, len(lines)):
            candidate = lines[i : j + 1]
            if not syntax_check("\n".join(candidate)):
                continue
            count = sum(1 for line in candidate if line.strip())
            if count > best_count:
                best_count = count
                best_start, best_end = i, j

    return "\n".join(lines[best_start : best_end + 1])
|
||||
|
||||
|
||||
def get_definition_name(node: Node) -> str:
    """Return the text of the first direct child of ``node`` that is an identifier.

    Only ``node.children`` is inspected (no deeper descent). The child's
    ``text`` is decoded as UTF-8.

    :param node: Node whose direct children are searched.
    :return: The decoded identifier text, or ``None`` when no direct child is an identifier.
    """
    identifier_children = (child for child in node.children if child.type == NodeType.IDENTIFIER.value)
    first_identifier = next(identifier_children, None)
    if first_identifier is None:
        return None
    return first_identifier.text.decode("utf8")
|
||||
|
||||
|
||||
def has_return_statement(node: Node) -> bool:
    """Tell whether any node yielded by ``traverse_tree(node)`` is a return statement.

    :param node: Node passed to ``traverse_tree``.
    :return: ``True`` as soon as a node of type ``NodeType.RETURN`` is seen, else ``False``.
    """
    return any(descendant.type == NodeType.RETURN.value for descendant in traverse_tree(node))
|
||||
|
||||
|
||||
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each named node to the set of identifier texts found beneath it.

    For every ``(name, node)`` pair the node's children are searched
    recursively. Identifier children contribute their UTF-8 decoded text;
    every other child is descended into. If a name occurs more than once in
    ``nodes``, the entry from the last occurrence is kept.

    :param nodes: ``(name, node)`` pairs to analyse.
    :return: Dictionary from ``name`` to the identifiers collected under its node.
    """

    def collect_identifiers(current: Node, found: Set[str]) -> None:
        # Identifiers are recorded; anything else is searched recursively.
        for child in current.children:
            if child.type != NodeType.IDENTIFIER.value:
                collect_identifiers(child, found)
                continue
            found.add(child.text.decode("utf8"))

    dependency_map = {}
    for name, node in nodes:
        identifiers = set()
        collect_identifiers(node, identifiers)
        dependency_map[name] = identifiers
    return dependency_map
|
||||
|
||||
|
||||
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from ``entrypoint`` in ``call_graph``.

    Performs a breadth-first search. Names that have no entry in
    ``call_graph`` are kept in the result but not expanded.

    :param entrypoint: Name to start from; always part of the result.
    :param call_graph: Mapping from a name to the names it refers to.
    :return: Set containing ``entrypoint`` and all names reachable from it.
    """
    # Imported here so the block does not depend on a file-level import.
    from collections import deque

    # deque gives O(1) pops from the left, unlike list.pop(0).
    pending = deque([entrypoint])
    visited = {entrypoint}
    while pending:
        current = pending.popleft()
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                pending.append(neighbour)
    return visited
|
||||
|
||||
|
||||
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize and extract relevant parts of the given Python code.

    The input is first reduced with ``code_extract`` and then parsed with
    tree-sitter. From the top level of the parsed tree the function keeps:

    * every import statement;
    * class definitions;
    * function definitions for which ``has_return_statement`` is true;
    * expression statements whose first child is an assignment.

    Only the first definition seen for a given name is kept, regardless of
    whether the earlier one was a class, function or variable. If an
    entrypoint is provided, a definition is emitted only when its name is
    reachable from the entrypoint according to ``get_deps`` /
    ``get_function_dependency``; imports are always emitted.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
    :return: The kept source fragments joined by newlines (imports first, then definitions in source order).
    """
    code = code_extract(code)
    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    tree = parser.parse(code_bytes)

    class_names = set()
    function_names = set()
    variable_names = set()
    import_nodes = []
    definition_nodes = []

    def already_defined(candidate):
        # A name may be claimed only once across classes, functions and variables.
        return candidate in class_names or candidate in function_names or candidate in variable_names

    # NOTE(review): get_definition_name returns None when a node has no direct
    # identifier child, so several such nodes would share the name None here
    # and only the first would be kept -- confirm whether that is intended.
    for child in tree.root_node.children:
        kind = child.type
        if kind in NodeType.IMPORT.value:
            import_nodes.append(child)
        elif kind == NodeType.CLASS.value:
            name = get_definition_name(child)
            if not already_defined(name):
                definition_nodes.append((name, child))
                class_names.add(name)
        elif kind == NodeType.FUNCTION.value:
            name = get_definition_name(child)
            # Functions without any return statement are dropped.
            if not already_defined(name) and has_return_statement(child):
                definition_nodes.append((name, child))
                function_names.add(name)
        elif kind == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
            assignment = child.children[0]
            name = get_definition_name(assignment)
            if not already_defined(name):
                definition_nodes.append((name, assignment))
                variable_names.add(name)

    if entrypoint:
        name2deps = get_deps(definition_nodes)
        reachable = get_function_dependency(entrypoint, name2deps)

    # Imports come first, followed by the retained definitions in source order.
    fragments = [code_bytes[node.start_byte : node.end_byte] for node in import_nodes]
    fragments.extend(
        code_bytes[node.start_byte : node.end_byte]
        for name, node in definition_nodes
        if not entrypoint or name in reachable
    )
    return b"\n".join(fragments).decode("utf8")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue