mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-04 13:22:39 +02:00
Update mbpp & math's eval
This commit is contained in:
parent
efa00f8bbb
commit
c194415b35
2 changed files with 181 additions and 544 deletions
@@ -1,396 +1,116 @@

import re
import regex
from pandas import Series
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from math import isclose
import multiprocessing
import json
import asyncio
import aiofiles
import pandas as pd
from typing import Optional, List, Tuple, Callable, Union, Any
from tqdm.asyncio import tqdm_asyncio
from datetime import datetime
import os
import inspect
from typing import Any, Callable, Tuple, List

from examples.aflow.benchmark.benchmark import BaseBenchmark

def extract_model_answer(text: str) -> str:
    # Extract the last \boxed{...}
    pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
    boxed_matches = re.findall(pattern, text, re.DOTALL)
    if boxed_matches:
        return boxed_matches[-1].strip()

class MATHBenchmark(BaseBenchmark):
    def __init__(self, name: str, file_path: str, log_path: str):
        super().__init__(name, file_path, log_path)

    # Extract the last sentence
    sentence_end_pattern = r'(?<!\d)[.!?]\s+'
    sentences = re.split(sentence_end_pattern, text)
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences[-1] if sentences else ""

    def extract_model_answer(self, text: str) -> str:
        pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
        boxed_matches = re.findall(pattern, text, re.DOTALL)
        if boxed_matches:
            return boxed_matches[-1].strip()

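Note: a quick standalone check of the boxed-answer regex used by both extractors above. The `(?:[^{}]|{[^{}]*})*` body tolerates one level of nested braces; the sample strings below are illustrative only.

import re

pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
text = r"The area is \boxed{\frac{1}{2}}, so the answer is \boxed{42}."
matches = re.findall(pattern, text, re.DOTALL)
assert matches[0] == r"\frac{1}{2}"   # one level of nesting survives
assert matches[-1].strip() == "42"    # the extractors keep the last box
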
def extract_answer(text: str) -> str:
    # Look for the answer within \boxed{...}
    boxed_match = re.search(r"\\boxed{(.*?)}", text)
    if boxed_match:
        return boxed_match.group(1).strip()
    sentence_end_pattern = r'(?<!\d)[.!?]\s+'
    sentences = re.split(sentence_end_pattern, text)
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences[-1] if sentences else ""

        sentence_end_pattern = r'(?<!\d)[.!?]\s+'
        sentences = re.split(sentence_end_pattern, text)

    def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]:
        expected_answer = self.extract_model_answer(expected_output)
        predicted_answer = self.extract_model_answer(prediction)

        # Filter out empty strings and return the last non-empty sentence
        sentences = [s.strip() for s in sentences if s.strip()]
        return sentences[-1] if sentences else ""

def get_function_code(func):
    try:
        source_code = inspect.getsource(func)
        return source_code
    except OSError:
        return "no code"

def parse_digits(num):
    # format: 234.23 || 23%
    num = regex.sub(",", "", str(num))
    try:
        return float(num)
    except:
        if num.endswith("%"):
            num = num[:-1]
            if num.endswith("\\"):
                num = num[:-1]
            try:
                return float(num) / 100
            except:
                pass
    return None

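Note: expected behavior of parse_digits, assuming the function above is in scope; the inputs are made-up examples.

assert parse_digits("1,234.5") == 1234.5   # commas are stripped before float()
assert parse_digits("23%") == 0.23         # a trailing % divides by 100
assert parse_digits("23\\%") == 0.23       # a stray trailing backslash is dropped too
assert parse_digits("x + 1") is None       # non-numeric input falls through to None
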
def is_digit(num):
    # paired with parse_digits
    return parse_digits(num) is not None

def symbolic_equal(a, b):
    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except:
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except:
        pass
    return False

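Note: a usage sketch for symbolic_equal, assuming sympy plus the antlr4 runtime that parse_latex depends on are installed; the expressions are illustrative.

assert symbolic_equal(r"\frac{1}{2}", "0.5")          # LaTeX vs. plain numeric
assert symbolic_equal("x**2 - 1", "(x - 1)*(x + 1)")  # simplify(a - b) == 0
assert not symbolic_equal("x + 1", "x + 2")
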
def call_with_timeout(func, *args, timeout=5, **kwargs):
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    return output_queue.get()

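Note: call_with_timeout appends a multiprocessing.Queue as the last positional argument, so the wrapped function must accept the queue and put its result on it. The worker below is a hypothetical example, not part of this file; also note that a timeout returns False, which is indistinguishable from a worker that legitimately put False on the queue.

def _is_equal_worker(a, b, output_queue):
    output_queue.put(a == b)

if __name__ == "__main__":
    print(call_with_timeout(_is_equal_worker, 1, 1, timeout=2))  # True
    print(call_with_timeout(_is_equal_worker, 1, 2, timeout=2))  # False
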
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal
    """
    if str(prediction) == str(reference):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close)
                    for i in range(len(pred_parts))
                ]
            ):
                return True

    if (
        (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}"))
        and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}"))
        and (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}"))
        and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}"))
    ):
        pred_lines = [
            line.strip()
            for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\")
            if line.strip()
        ]
        ref_lines = [
            line.strip()
            for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\")
            if line.strip()
        ]
        matched = True
        if len(pred_lines) == len(ref_lines):
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) == len(ref_parts):
                    if not all(
                        [
                            math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close)
                            for i in range(len(pred_parts))
                        ]
                    ):
                        matched = False
                        break
                else:
                    matched = False
                if not matched:
                    break
        else:
            matched = False
        if matched:
            return True

        if self.math_equal(predicted_answer, expected_answer):
            return 1, predicted_answer
        else:
            return 0, predicted_answer

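Note: illustrative checks for the module-level math_equal above (the LaTeX case additionally needs the antlr4 runtime for parse_latex); the inputs are made up.

assert math_equal("50%", "0.5")            # percent handling via parse_digits
assert math_equal("(1, 2)", "(1.0, 2.0)")  # element-wise tuple comparison
assert math_equal("0.5", "\\frac{1}{2}")   # symbolic fallback through sympy
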
    def math_equal(self, prediction: Any, reference: Any) -> bool:
        if str(prediction) == str(reference):
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif prediction.count("=") == 1 and len(prediction.split("=")[0].strip()) <= 2 and "=" not in reference:
        if math_equal(prediction.split("=")[1], reference, include_percentage, is_close):
            return True
    elif reference.count("=") == 1 and len(reference.split("=")[0].strip()) <= 2 and "=" not in prediction:
        if math_equal(prediction, reference.split("=")[1], include_percentage, is_close):
            return True
        try:
            if self.is_digit(prediction) and self.is_digit(reference):
                prediction = self.parse_digits(prediction)
                reference = self.parse_digits(reference)
                return isclose(prediction, reference, abs_tol=1e-3)
        except:
            pass

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

        try:
            return self.symbolic_equal(prediction, reference)
        except:
            pass

    return False
        return False

    def is_digit(self, num):
        return self.parse_digits(num) is not None

def calculate_score(expected_output: str, prediction: str) -> tuple[int, str]:
    expected_answer = extract_model_answer(expected_output)
    predicted_answer = extract_model_answer(prediction)
    def parse_digits(self, num):
        num = regex.sub(",", "", str(num))
        try:
            return float(num)
        except:
            if num.endswith("%"):
                num = num[:-1]
                if num.endswith("\\"):
                    num = num[:-1]
                try:
                    return float(num) / 100
                except:
                    pass
        return None

    if math_equal(predicted_answer, expected_answer):
        return 1, predicted_answer
    else:
        return 0, predicted_answer

    def symbolic_equal(self, a, b):
        def _parse(s):
            for f in [parse_latex, parse_expr]:
                try:
                    return f(s)
                except:
                    pass
            return s

        a = _parse(a)
        b = _parse(b)

def ensure_log_file_exists(path: str):
    log_file = os.path.join(path, 'log.json')
    if not os.path.exists(log_file):
        with open(log_file, 'w', encoding='utf-8') as f:
            json.dump([], f, indent=4, ensure_ascii=False)
        try:
            if simplify(a - b) == 0:
                return True
        except:
            pass

        try:
            if isclose(N(a), N(b), abs_tol=1e-3):
                return True
        except:
            pass
        return False

def log_mismatch(problem: str, expected_output: float, prediction: str, predicted_number, path):
    log_data = {
        "question": problem,
        "right_answer": expected_output,
        "model_output": prediction,
        "extracted_output": predicted_number
    }
    async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, int, float]:
        input_text = problem["problem"]
        expected_output = problem["solution"]
        max_retries = 2
        retries = 0

    # Get the source code of the passed-in function
    function_code = get_function_code(extract_model_answer)
    log_data["extract_answer_code"] = function_code  # new field
        prediction = await graph(input_text)
        cost = prediction[1]
        output = prediction[0]
        uni_score, extracted_output = self.calculate_score(expected_output, output)

    log_file = os.path.join(path, 'log.json')

    # Check whether the log file already exists
    if os.path.exists(log_file):
        # If it exists, load the existing log data
        with open(log_file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []
    else:
        # If it does not, start a new log list
        data = []
        if uni_score == 0:
            self.log_mismatch(input_text, expected_output, output, extracted_output)

    # Append the new log record
    data.append(log_data)

        return input_text, output, expected_output, uni_score, cost

    # Write the data back to the log.json file
    with open(log_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

async def load_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
    data = []
    # Read the file contents asynchronously
    async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
        async for line in file:
            data.append(json.loads(line))

    # Then filter further, within the selected samples, by the specific index list
    if specific_indices is not None:
        filtered_data = [data[i] for i in specific_indices if i < len(data)]
        return filtered_data

    return data

def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
    # Build the DataFrame
    df = pd.DataFrame(results, columns=["question", "prediction", "expected_output", "score", "cost"])

    # Compute summary statistics
    avg_score = df["score"].mean()
    t_cost = df["cost"].max()
    a_cost = t_cost / len(df) if len(df) > 0 else 0

    # Current time, formatted as YYYYMMDD_HHMMSS
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Build the filename from the average score (five decimal places) and the current time
    filename = f"{avg_score:.5f}_{current_time}.csv"
    output_file = os.path.join(path, filename)

    # Save to CSV
    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return avg_score, a_cost, t_cost

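Note: a small driver sketch for save_results_to_csv with synthetic rows. The cost column appears to be treated as a running total, so its max is reported as the total cost and a_cost is that total divided by the row count.

rows = [
    ("Q1", "42", "42", 1, 0.01),
    ("Q2", "7", "8", 0, 0.02),
]
avg_score, a_cost, t_cost = save_results_to_csv(rows, path=".")
print(avg_score, a_cost, t_cost)  # 0.5 0.01 0.02
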
async def evaluate_problem(problem: dict, graph, log_path) -> Tuple[str, str, str, int, str]:
    input_text = problem["problem"]
    expected_output = problem["solution"]
    max_retries = 2
    retries = 0

    prediction = await graph(input_text) if graph else "None"
    cost = prediction[1]
    output = prediction[0]

    uni_score, extracted_output = calculate_score(expected_output, output)

    if uni_score == 0:
        log_mismatch(input_text, expected_output, output, extracted_output, log_path)
    else:
        ensure_log_file_exists(log_path)

    # while retries < max_retries:
    #     try:
    #         prediction = await graph(input_text) if graph else "None"
    #         cost = prediction[1]
    #         output = prediction[0]
    #
    #         uni_score, extracted_output = calculate_score(expected_output, output)
    #
    #         if uni_score == 0:
    #             log_mismatch(input_text, expected_output, output, extracted_output, log_path)
    #         else:
    #             ensure_log_file_exists(log_path)
    #
    #         break
    #
    #     except Exception as e:
    #         retries += 1
    #         print(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
    #
    #         if retries == max_retries:
    #             print("Maximum retries reached. Skipping this sample.")
    #             output = e
    #             cost = None
    #             uni_score = 0
    #             break

    return input_text, output, expected_output, uni_score, cost

async def evaluate_all_problems(data: List[dict], graph, path, max_concurrent_tasks: int = 300):
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def sem_evaluate(problem):
        async with semaphore:
            return await evaluate_problem(problem, graph, path)

    tasks = [sem_evaluate(problem) for problem in data]

    return await tqdm_asyncio.gather(*tasks, desc="Evaluating MATH problems", total=len(data))

async def optimize_math_evaluation(graph: Callable, file_path: str, path: str, va_list: list) -> tuple[
    Any, Any, Any]:
    data = await load_data(file_path, va_list)
    results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=30)
    average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
    print(f"Average score on MATH dataset: {average_score:.5f}")
    print(f"Total Cost: {total_cost:.5f}")
    return average_score, average_cost, total_cost

    def get_result_columns(self) -> List[str]:
        return ["question", "prediction", "expected_output", "score", "cost"]

@@ -2,205 +2,122 @@ import os
import json
import time
import asyncio
import aiofiles
import threading
import pandas as pd
from typing import List, Tuple, Callable, Any, Optional, Dict
from datetime import datetime
from typing import List, Tuple, Callable, Any, Optional, Dict

from tqdm.asyncio import tqdm_asyncio
from examples.aflow.benchmark.utils import log_mismatch
from metagpt.actions.code_sanitize import sanitize
from examples.aflow.benchmark.utils import generate_random_indices
from examples.aflow.benchmark.benchmark import BaseBenchmark

PASS = "pass"
FAIL = "fail"
class MBPPBenchmark(BaseBenchmark):
    def __init__(self, name: str, file_path: str, log_path: str):
        super().__init__(name, file_path, log_path)

async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
    data = []
    async with aiofiles.open(file_path, mode="r") as file:
        async for line in file:
            data.append(json.loads(line))
    random_indices = generate_random_indices(len(data), samples, test)
    data = [data[i] for i in random_indices]
    return data
PASS = "PASS"
|
||||
FAIL = "FAIL"
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
PASS = "PASS"
|
||||
FAIL = "FAIL"
|
||||
    def run_with_timeout(self, func, timeout):
        result = []
        stop_event = threading.Event()

class TimeoutError(Exception):
    pass
        def target():
            try:
                result.append(func())
            except Exception as e:
                result.append(e)
            finally:
                stop_event.set()

def run_with_timeout(func, timeout):
    result = []
    stop_event = threading.Event()

        thread = threading.Thread(target=target)
        thread.start()
        is_timeout = not stop_event.wait(timeout)

    def target():

        if is_timeout:
            raise self.TimeoutError("Function execution timed out")

        if not result:
            return None
        if isinstance(result[0], Exception):
            raise result[0]
        return result[0]

    def check_solution(self, solution, test, entry_point):
        solution = sanitize(code=solution, entrypoint=entry_point)
        try:
            result.append(func())
            global_dict = {
                'math': __import__('math'),
                'hashlib': __import__('hashlib'),
                're': __import__('re'),
                'List': List,
                'Dict': Dict,
                'Tuple': Tuple,
                'Optional': Optional,
                'Any': Any
            }

            exec(solution, global_dict)

            if entry_point not in global_dict:
                raise ValueError(f"Function {entry_point} is not defined in the solution.")

            exec(test, global_dict)

            check = global_dict["check"]

            result = self.run_with_timeout(check, 15)

            if result is None:
                result = (self.PASS, "The solution passed all test cases.")

        except self.TimeoutError:
            result = (self.FAIL, "Execution timed out. Please check if your solution contains infinite loops or overly time-consuming operations.")
        except Exception as e:
            result.append(e)
        finally:
            stop_event.set()

    thread = threading.Thread(target=target)
    thread.start()
    is_timeout = not stop_event.wait(timeout)

    if is_timeout:
        # The thread is still running; we cannot force-kill it, but we can at least flag the timeout
        raise TimeoutError("Function execution timed out")

    if not result:
        return None
    if isinstance(result[0], Exception):
        raise result[0]
    return result[0]

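Note: a usage sketch for the thread-based run_with_timeout above; the lambdas are illustrative. A fast call returns its value, while a hung call raises the module's TimeoutError (the worker thread itself cannot be killed and is left to finish in the background).

print(run_with_timeout(lambda: 2 + 2, timeout=1))  # 4
try:
    run_with_timeout(lambda: time.sleep(10), timeout=1)
except TimeoutError:
    print("timed out as expected")
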
def check_solution(solution, test, entry_point):

    solution = sanitize(code=solution, entrypoint=entry_point)
    try:
        # Define a globals dict that contains all the required modules
        global_dict = {
            'math': __import__('math'),
            'hashlib': __import__('hashlib'),
            're': __import__('re'),
            'List': List,
            'Dict': Dict,
            'Tuple': Tuple,
            'Optional': Optional,
            'Any': Any
        }
        # Execute the solution
        exec(solution, global_dict)

            error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}"
            result = (self.FAIL, error_message)

            with open('error.log', 'a', encoding='utf-8') as log_file:
                log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")

        # Make sure the entry-point function is defined
        if entry_point not in global_dict:
            raise ValueError(f"Function {entry_point} is not defined in the solution.")

        # Run the test cases
        exec(test, global_dict)

        # Get the check function
        check = global_dict["check"]

        # Run the check function with a 15-second timeout
        result = run_with_timeout(check, 15)

        if result is None:
            result = (PASS, "The solution passed all test cases.")

    except TimeoutError:
        result = (FAIL, "Execution timed out. Please check if your solution contains infinite loops or overly time-consuming operations.")
    except Exception as e:
        # Record the detailed error information
        error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}"
        result = (FAIL, error_message)

        # Write the error message to the error.log file
        with open('error.log', 'a', encoding='utf-8') as log_file:
            log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")

    return result
        return result

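Note: an illustrative MBPP-style call to the module-level check_solution above; the solution/test strings are made-up stand-ins for a dataset row, and this assumes sanitize passes valid code through unchanged.

solution = "def add(a, b):\n    return a + b"
test = "def check():\n    assert add(1, 2) == 3\n    assert add(-1, 1) == 0"
print(check_solution(solution, test, "add"))
# expected: (PASS, "The solution passed all test cases.")
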
async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str, str, int, str]:
    max_retries = 5
    retries = 0

    async def evaluate_problem(self, data: dict, graph: Callable) -> Tuple[str, str, str, float, float]:
        max_retries = 5
        retries = 0

    expected_output = "\nCorrect Solution:\ndef " + data["code"]
        expected_output = "\nCorrect Solution:\ndef " + data["code"]

    while retries < max_retries:
        try:
            prediction = await graph(data["prompt"], data["entry_point"]) if graph else "None"
            cost = prediction[1]
            solution = prediction[0]
            ret = check_solution(solution, data["test"], data["entry_point"])
            test_case_details = ret[1]
            expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
            score = 1 if ret[0] == PASS else 0
        while retries < max_retries:
            try:
                prediction, cost = await graph(data["prompt"], data["entry_point"])
                ret = self.check_solution(prediction, data["test"], data["entry_point"])
                test_case_details = ret[1]
                expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
                score = 1.0 if ret[0] == self.PASS else 0.0

            if score == 0:
                log_mismatch(data["prompt"], expected_output, solution, score, path)
            break

        except Exception as e:
            retries += 1
            print(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")

            if retries == max_retries:
                print("Maximum retries reached. Skipping this sample.")
                solution = None
                ret = (FAIL, [])
                score = 0
                cost = 0
                if score == 0:
                    self.log_mismatch(data["prompt"], expected_output, prediction, score)
                break

    return data["prompt"], solution, expected_output, score, cost

            except Exception as e:
                retries += 1
                print(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")

async def evaluate_all_problems(data: List[dict], graph: Callable, path: str = "", max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int, str]]:
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

                if retries == max_retries:
                    print("Maximum retries reached. Skipping this sample.")
                    prediction = None
                    ret = (self.FAIL, [])
                    score = 0.0
                    cost = 0.0
                    break

    async def sem_evaluate(problem):
        async with semaphore:
            return await evaluate_problem(problem, graph, path)

        return data["prompt"], prediction, expected_output, score, cost

    tasks = [sem_evaluate(problem) for problem in data]

    def calculate_score(self, expected_output: str, prediction: str) -> Tuple[float, str]:
        # The scoring logic for MBPP is already implemented in evaluate_problem; this exists only to conform to the interface
        return 0.0, prediction

return await tqdm_asyncio.gather(*tasks, desc="Evaluating MBPP problems", total=len(data))
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
    # Build the DataFrame
    df = pd.DataFrame(results, columns=["question", "prediction", "expected_output", "score", "cost"])

    # Compute summary statistics
    avg_score = df["score"].mean()
    t_cost = df["cost"].max()
    a_cost = t_cost / len(df) if len(df) > 0 else 0

    # Current time, formatted as YYYYMMDD_HHMMSS
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Build the filename from the average score (five decimal places) and the current time
    filename = f"{avg_score:.5f}_{current_time}.csv"
    output_file = os.path.join(path, filename)

    # Save to CSV
    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return avg_score, a_cost, t_cost

async def load_file_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
    data = []
    # Read the file contents asynchronously
    async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
        async for line in file:
            data.append(json.loads(line))

    # Then filter further, within the selected samples, by the specific index list
    if specific_indices is not None:
        filtered_data = [data[i] for i in specific_indices if i < len(data)]
        return filtered_data

    return data

async def mbpp_evaluation(graph: Callable, file_path: str, samples: int, path: str, test=False) -> Tuple[float, float]:
    data = await load_data(file_path, samples, test)
    results = await evaluate_all_problems(data, graph, max_concurrent_tasks=25)
    average_score, total_cost = save_results_to_csv(results, path=path)
    print(f"Average score on MBPP dataset: {average_score:.5f}")
    print(f"Total Cost: {total_cost:.5f}")
    return average_score, total_cost

async def optimize_mbpp_evaluation(graph: Callable, file_path: str, path: str, va_list: List[int]) -> Tuple[float, float]:
    data = await load_file_data(file_path, va_list)
    results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=25)
    average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
    print(f"Average score on MBPP dataset: {average_score:.5f}")
    print(f"Total Cost: {total_cost:.5f}")
    print(f"Average cost on MBPP dataset: {average_cost:.5f}")
    return average_score, average_cost, total_cost

    def get_result_columns(self) -> List[str]:
        return ["inputs", "prediction", "expected_output", "score", "cost"]
|||
Loading…
Add table
Add a link
Reference in a new issue