mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
重构了Evaluator
This commit is contained in:
parent
4e0a896bdc
commit
7ffe68b499
6 changed files with 846 additions and 1212 deletions
186
examples/ags/benchmark/drop.py
Normal file
186
examples/ags/benchmark/drop.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
import json
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Callable, Dict, Any, Set
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
def is_number(text: str) -> bool:
    """Return True when *text* can be parsed by ``float``.

    Anything ``float`` accepts counts as a number, including exponent
    notation and the special spellings ``"nan"`` / ``"inf"``.  Values that
    ``float`` rejects -- malformed strings as well as non-string objects such
    as ``None`` -- yield ``False`` instead of raising.
    """
    try:
        float(text)
    except (TypeError, ValueError):
        # TypeError covers None / containers; ValueError covers bad strings.
        return False
    return True
|
||||
|
||||
def normalize_answer(text):
    """Normalize an answer string for DROP-style token comparison.

    The text is split on spaces and hyphens.  Each token is lower-cased and
    then:

    * if it already parses as a number it is rewritten as ``str(float(tok))``
      with its punctuation intact, so ``"3.5"`` stays ``"3.5"`` rather than
      collapsing to ``"35.0"``;
    * otherwise punctuation is removed, a second numeric check is made (so
      ``"1,000"`` becomes ``"1000.0"``), and finally the articles
      ``a`` / ``an`` / ``the`` are dropped and whitespace is collapsed.

    Empty tokens are discarded and the survivors are joined with single
    spaces.
    """
    import re
    import string

    punctuation = set(string.punctuation)

    def normalize_token(token: str) -> str:
        token = token.lower()
        # Check for a number *before* stripping punctuation: removing the
        # decimal point would otherwise change the value.
        if is_number(token):
            return str(float(token))
        token = "".join(ch for ch in token if ch not in punctuation)
        if is_number(token):
            return str(float(token))
        token = re.sub(r"\b(a|an|the)\b", " ", token)
        return " ".join(token.split())

    normalized_tokens = (normalize_token(token) for token in re.split(" |-", text))
    return " ".join(token for token in normalized_tokens if token.strip()).strip()
|
||||
|
||||
def answer_to_bags(answer: str) -> Tuple[List[str], List[Set[str]]]:
    """Convert one answer string into its normalized span and token bag.

    Returns a pair ``(normalized_spans, token_bags)``.  Both lists have
    exactly one element because the answer is treated as a single span:
    the normalized text, and the set of its whitespace-separated tokens.
    """
    normalized_span = normalize_answer(answer)
    return [normalized_span], [set(normalized_span.split())]
|
||||
|
||||
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
    """Score the best one-to-one pairing between gold and predicted bags.

    An F1 matrix (rows = gold bags, columns = predicted bags) is filled in;
    a cell stays at zero when ``match_numbers_if_present`` rejects the pair.
    ``linear_sum_assignment`` on the negated matrix picks the assignment
    with the highest total F1.  The result has ``max(len(gold),
    len(predicted))`` entries, indexed by gold row; unassigned slots remain
    zero.
    """
    score_matrix = np.zeros([len(gold), len(predicted)])
    for gold_idx, gold_bag in enumerate(gold):
        for pred_idx, predicted_bag in enumerate(predicted):
            if not match_numbers_if_present(gold_bag, predicted_bag):
                continue
            score_matrix[gold_idx, pred_idx] = f1_score(predicted_bag, gold_bag)

    # The solver minimizes cost, hence the negation to maximize F1.
    gold_rows, predicted_cols = linear_sum_assignment(-score_matrix)

    best_scores = np.zeros([max(len(gold), len(predicted))])
    for gold_idx, pred_idx in zip(gold_rows, predicted_cols):
        best_scores[gold_idx] = max(best_scores[gold_idx], score_matrix[gold_idx, pred_idx])
    return best_scores
|
||||
|
||||
def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    """Check that the prediction shares a number with the gold answer.

    Returns True when the gold bag contains no numeric token at all, or when
    at least one numeric token of the gold bag also appears in the predicted
    bag.  Returns False otherwise.
    """
    gold_numbers = {word for word in gold_bag if is_number(word)}
    predicted_numbers = {word for word in predicted_bag if is_number(word)}
    if not gold_numbers:
        return True
    return bool(gold_numbers.intersection(predicted_numbers))
|
||||
|
||||
def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    """Compute token-level F1 between a predicted bag and a gold bag.

    An empty predicted bag is given precision 1.0 and an empty gold bag is
    given recall 1.0.  When precision and recall are both zero the score is
    0.0; otherwise it is their harmonic mean.
    """
    overlap = len(gold_bag.intersection(predicted_bag))
    precision = overlap / float(len(predicted_bag)) if predicted_bag else 1.0
    recall = overlap / float(len(gold_bag)) if gold_bag else 1.0
    if precision == 0.0 and recall == 0.0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)
|
||||
|
||||
def load_data(file_path: str) -> List[Tuple[str, Dict[str, Any]]]:
    """Load a JSON object from *file_path* and return its items as a list.

    The file is read as UTF-8 explicitly so that non-ASCII text decodes the
    same way on every platform instead of depending on the locale default.
    Each returned element is a ``(key, value)`` pair of the top-level object.
    """
    with open(file_path, mode="r", encoding="utf-8") as file:
        data = json.load(file)
    return list(data.items())
|
||||
|
||||
async def evaluate_problem(question: str, passage: str, answers: List[Dict[str, Any]], graph: Callable) -> Tuple[str, str, float]:
    """Run *graph* on one question and score it against every gold answer.

    ``graph(question, passage)`` is awaited once to obtain the prediction.
    Each gold answer dict is converted to strings and the prediction's F1
    against it is computed; the highest score wins.

    Returns ``(best_answer, prediction, max_score)``.  ``best_answer`` is the
    gold string that achieved ``max_score``; when no gold answer scores above
    zero it is the first gold answer (rather than ``None``) so result rows
    always carry a reference.  It is ``None`` only if *answers* is empty.

    Raises:
        ValueError: if a gold answer has none of "number", "spans", "date".
    """

    def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
        # Priority order: non-empty number, then non-empty spans, then date.
        if "number" in answer and answer["number"]:
            return tuple([str(answer["number"])]), "number"
        if "spans" in answer and answer["spans"]:
            return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans"
        if "date" in answer:
            date = answer["date"]
            return tuple(["{0} {1} {2}".format(date["day"], date["month"], date["year"])]), "date"
        raise ValueError(f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}")

    def get_f1_score(prediction: str, golden_answer: str) -> float:
        _, predicted_bags = answer_to_bags(prediction)
        _, gold_bags = answer_to_bags(golden_answer)
        return float(np.mean(_align_bags(predicted_bags, gold_bags)))

    prediction = await graph(question, passage)

    max_score = 0.0
    best_answer = None
    for answer in answers:
        gold_strings, _ = answer_json_to_strings(answer)
        # NOTE(review): only the first string is scored, so a multi-span gold
        # answer is reduced to its first span -- confirm this is intended.
        golden_answer = gold_strings[0]
        score = get_f1_score(prediction, golden_answer)
        # The first gold answer is always recorded so best_answer is not None
        # when every score is zero.
        if best_answer is None or score > max_score:
            max_score = score
            best_answer = golden_answer

    return best_answer, prediction, max_score
|
||||
|
||||
async def evaluate_all_passages(annotations: List[Tuple[str, Dict[str, Any]]], graph: Callable, max_concurrent_tasks: int = 50) -> List[List[Any]]:
    """Evaluate every question of every passage with bounded concurrency.

    One task is created per ``(id, annotation)`` pair; a semaphore limits how
    many passages are processed at once, and the questions of a passage are
    evaluated one after another while its slot is held.  For each question
    the gold answers are ``qa_pair["answer"]`` followed by any non-empty
    ``qa_pair["validated_answers"]``.

    Returns rows of ``[id, question, prediction, best_answer, score]`` in the
    order the evaluations finish.
    """
    limiter = asyncio.Semaphore(max_concurrent_tasks)
    rows = []

    async def sem_evaluate(passage_id: str, annotation: Dict[str, Any]):
        async with limiter:
            passage = annotation["passage"]
            for qa_pair in annotation["qa_pairs"]:
                question = qa_pair["question"]
                gold_answers = [qa_pair["answer"]]
                validated = qa_pair["validated_answers"] if "validated_answers" in qa_pair else None
                if validated:
                    gold_answers.extend(validated)
                best_answer, prediction, score = await evaluate_problem(question, passage, gold_answers, graph)
                rows.append([passage_id, question, prediction, best_answer, score])

    pending = [sem_evaluate(passage_id, annotation) for passage_id, annotation in annotations]
    await tqdm_asyncio.gather(*pending, desc="Evaluating DROP passages", total=len(annotations))

    return rows
|
||||
|
||||
def save_results_to_csv(results: List[List[Any]], path: str) -> float:
    """Write evaluation rows to a CSV named after their mean score.

    *results* rows are ``[id, question, prediction, best_answer, score]``.
    The file is written to ``{path}/{mean:.5f}.csv`` without an index column,
    the location is printed, and the mean score is returned.
    """
    column_names = ["id", "question", "prediction", "best_answer", "score"]
    frame = pd.DataFrame(results, columns=column_names)
    average_score = frame["score"].mean()

    output_file = f"{path}/{average_score:.5f}.csv"
    frame.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return average_score
|
||||
|
||||
async def drop_evaluation(graph: Callable, file_path: str, path: str) -> float:
    """Evaluate *graph* on the DROP file at *file_path* and save the results.

    Loads the annotations, evaluates all passages with at most 20 concurrent
    tasks, writes the CSV under *path*, prints the mean score and returns it.
    """
    annotations = load_data(file_path)
    rows = await evaluate_all_passages(annotations, graph, max_concurrent_tasks=20)
    average_score = save_results_to_csv(rows, path=path)
    print(f"Average score on DROP dataset: {average_score:.5f}")
    return average_score
|
||||
|
|
@ -1,123 +1,25 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from typing import Literal, Optional
|
||||
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Set
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from examples.ags.scripts.graph import HotpotQAGraph
|
||||
from examples.ags.scripts.operator import Format, GenerateOnContext
|
||||
from examples.ags.scripts.utils import get_hotpotqa
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
HOTPOTQA_PATH = "hotpotqa_1000.jsonl"
|
||||
def is_number(text: str) -> bool:
    """Return True when *text* can be parsed by ``float``.

    Values that ``float`` rejects -- malformed strings as well as non-string
    objects such as ``None`` -- yield ``False`` instead of raising.
    """
    try:
        float(text)
    except (TypeError, ValueError):
        # TypeError covers None / containers; ValueError covers bad strings.
        return False
    return True
|
||||
|
||||
def normalize_answer(text):
|
||||
import re
|
||||
import string
|
||||
|
||||
def sort_json_by_key(input_path, output_path):
    """Rewrite a JSON-lines file with its records sorted by ``task_id``.

    Every line of *input_path* is parsed as JSON, the records are sorted on
    their "task_id" value, and each is written back to *output_path* as one
    JSON document per line.  The input is read completely before the output
    is opened, so both paths may name the same file.
    """
    with open(input_path) as source:
        records = [json.loads(raw_line) for raw_line in source]
    records.sort(key=lambda record: record["task_id"])
    with open(output_path, "w") as target:
        for record in records:
            target.write(json.dumps(record) + "\n")
|
||||
|
||||
|
||||
extract_supporting_sentences = GenerateOnContext(
|
||||
llm=LLM(), requirement="supporting sentences to get the final answers (split by newline)"
|
||||
)
|
||||
generate_on_context = GenerateOnContext(llm=LLM(), requirement="a concise answer without additional context")
|
||||
format = Format(llm=LLM())
|
||||
solver = HotpotQAGraph(
|
||||
name="solver",
|
||||
llm=LLM(),
|
||||
criteria="correctness, only concise answer, without additional context",
|
||||
HOTPOTQA_PATH=HOTPOTQA_PATH,
|
||||
)
|
||||
|
||||
ModeType = Literal["ags", "alpha_codium", "llm"]
|
||||
|
||||
|
||||
async def llm_generate(id):
    """Answer one HotpotQA item with the plain three-step LLM pipeline.

    Steps: extract supporting sentences from the joined context paragraphs,
    generate an answer from those sentences, then format the answer.

    Returns a dict with ``task_id``, ``answer`` and ``supporting_sentences``
    (the extracted text split on newlines).
    """
    dp = get_hotpotqa(HOTPOTQA_PATH)[id]
    paragraphs = [item[1] for item in dp["context"] if isinstance(item[1], list)]
    context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)

    supporting_sentences = await extract_supporting_sentences(dp["question"], context_str)
    # The "solution" value is split on "\n" below, i.e. it is a string.
    # Joining a string with "\n" would put a newline between every
    # character, so it is passed through unchanged.
    supporting_sentences_str = supporting_sentences.get("solution")

    answer_result = await generate_on_context(dp["question"], supporting_sentences_str)
    answer_result = answer_result.get("solution")

    answer_formated = await format(dp["question"], answer_result)
    sample_dict = dict(
        task_id=id,
        answer=answer_formated.get("solution"),
        supporting_sentences=supporting_sentences_str.split("\n"),
    )
    return sample_dict
|
||||
|
||||
|
||||
async def route_generate(mode: ModeType, id):
    """Dispatch sample generation for *id* to the backend named by *mode*.

    "ags" awaits the module-level ``solver``; "llm" awaits ``llm_generate``.

    Raises:
        ValueError: for any other mode value.
    """
    if mode == "ags":
        return await solver(id)
    if mode == "llm":
        return await llm_generate(id)
    raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
|
||||
async def sample_generate(id, result_path: str = "samples.jsonl", mode: ModeType = "llm"):
    """Generate one sample for *id* and append it to *result_path* as a JSON line."""
    record = await route_generate(mode, id)
    async with aiofiles.open(result_path, mode="a") as out:
        await out.write(json.dumps(record) + "\n")
|
||||
|
||||
|
||||
async def samples_generate(
    mode: ModeType, data_path: str = HOTPOTQA_PATH, result_path: str = "samples.jsonl", max_concurrency: int = 50
):
    """Generate samples for every HotpotQA id, retry failures, then evaluate.

    All ids are processed concurrently (bounded by *max_concurrency*) and
    appended to *result_path* under a lock.  Ids whose generation raised are
    retried once, sequentially.  The result file is then sorted by task id
    and, if nothing is still failing, scored against *data_path*.
    """
    ids = list(get_hotpotqa(HOTPOTQA_PATH).keys())

    file_lock = asyncio.Lock()
    semaphore = asyncio.Semaphore(max_concurrency)

    async def answer_and_write(mode: ModeType, id) -> Optional[str]:
        """Return None on success, or the id when generation failed."""
        async with semaphore:
            try:
                sample_dict = await route_generate(mode, id)
            except Exception:
                # Best effort: the id is reported back and retried below.
                return id
            async with file_lock:
                async with aiofiles.open(result_path, mode="a") as f:
                    await f.write(json.dumps(sample_dict) + "\n")
            return None

    tasks = [answer_and_write(mode, id) for id in ids]
    results = await asyncio.gather(*tasks)
    failed_ids = [id for id in results if id is not None]

    if failed_ids:
        logger.info(failed_ids)
        # Collect the ids that fail again in a new list; removing entries
        # from failed_ids while iterating over it would skip every element
        # that follows a successful retry.
        still_failed = []
        for id in failed_ids:
            try:
                await sample_generate(id, result_path, mode)
            except Exception:
                logger.error(f"Failed to generate sample for id: {id}")
                still_failed.append(id)
        failed_ids = still_failed

    sort_json_by_key(result_path, result_path)

    if not failed_ids:
        # Drops the last six characters of result_path (".jsonl" for the
        # default) before appending the evaluation suffix.
        eval_path = result_path[:-6] + "_eval.json"
        # `eval` here is this module's eval(prediction_file, gold_file,
        # eval_file) function, which shadows the builtin.
        logger.info(eval(result_path, data_path, eval_path))
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
|
|
@ -131,43 +33,143 @@ def normalize_answer(s):
|
|||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
def tokenize(text):
|
||||
return re.split(" |-", text)
|
||||
|
||||
def normalize_number(text: str) -> str:
|
||||
if is_number(text):
|
||||
return str(float(text))
|
||||
else:
|
||||
return text
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
parts = [
|
||||
white_space_fix(remove_articles(normalize_number(remove_punc(lower(token)))))
|
||||
for token in tokenize(text)
|
||||
]
|
||||
parts = [part for part in parts if part.strip()]
|
||||
normalized = " ".join(parts).strip()
|
||||
return normalized
|
||||
|
||||
def answer_to_bags(answer: str) -> Set[str]:
|
||||
raw_spans = [answer]
|
||||
|
||||
def eval(prediction_file, gold_file, eval_file):
|
||||
# if existing eval file
|
||||
if os.path.exists(eval_file):
|
||||
# read the result
|
||||
with open(eval_file) as f:
|
||||
eval_results = [json.loads(line) for line in f]
|
||||
em = sum([result["em"] for result in eval_results])
|
||||
logger.info(f"EM: {em/len(eval_results)}")
|
||||
return
|
||||
normalized_spans = []
|
||||
token_bags = []
|
||||
for raw_span in raw_spans:
|
||||
normalized_span = normalize_answer(raw_span)
|
||||
normalized_spans.append(normalized_span)
|
||||
token_bags.append(set(normalized_span.split()))
|
||||
return normalized_spans, token_bags
|
||||
|
||||
sort_json_by_key(prediction_file, prediction_file)
|
||||
with open(prediction_file) as f:
|
||||
predictions = [json.loads(line) for line in f]
|
||||
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
||||
"""
|
||||
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
|
||||
between them and gets maximum metric values over all the answers.
|
||||
"""
|
||||
scores = np.zeros([len(gold), len(predicted)])
|
||||
for gold_index, gold_item in enumerate(gold):
|
||||
for pred_index, pred_item in enumerate(predicted):
|
||||
if match_numbers_if_present(gold_item, pred_item):
|
||||
scores[gold_index, pred_index] = f1_score(pred_item, gold_item)
|
||||
row_ind, col_ind = linear_sum_assignment(-scores)
|
||||
|
||||
with open(gold_file) as f:
|
||||
golds = [json.loads(line) for line in f]
|
||||
max_scores = np.zeros([max(len(gold), len(predicted))])
|
||||
for row, column in zip(row_ind, col_ind):
|
||||
max_scores[row] = max(max_scores[row], scores[row, column])
|
||||
return max_scores
|
||||
|
||||
eval_results = []
|
||||
em = 0
|
||||
for prediction, gold in zip(predictions, golds):
|
||||
if prediction["task_id"] != gold["_id"]:
|
||||
raise ValueError(f"Task ID {gold['_id']} do not match")
|
||||
result = exact_match_score(prediction["answer"], gold["answer"])
|
||||
em += result
|
||||
eval_results.append(
|
||||
{"task_id": prediction["task_id"], "solution": prediction["answer"], "answer": gold["answer"], "em": result}
|
||||
)
|
||||
def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
|
||||
gold_numbers = set()
|
||||
predicted_numbers = set()
|
||||
for word in gold_bag:
|
||||
if is_number(word):
|
||||
gold_numbers.add(word)
|
||||
for word in predicted_bag:
|
||||
if is_number(word):
|
||||
predicted_numbers.add(word)
|
||||
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
|
||||
return True
|
||||
return False
|
||||
|
||||
with open(eval_file, "w") as f:
|
||||
for line in eval_results:
|
||||
f.write(json.dumps(line) + "\n")
|
||||
def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
||||
intersection = len(gold_bag.intersection(predicted_bag))
|
||||
if not predicted_bag:
|
||||
precision = 1.0
|
||||
else:
|
||||
precision = intersection / float(len(predicted_bag))
|
||||
if not gold_bag:
|
||||
recall = 1.0
|
||||
else:
|
||||
recall = intersection / float(len(gold_bag))
|
||||
f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
|
||||
return f1
|
||||
|
||||
logger.info(f"EM: {em/len(predictions)}")
|
||||
async def load_data(file_path: str, samples=20) -> List[dict]:
    """Read a JSON-lines file and return a sampled subset of its records.

    Every line is parsed as JSON; ``generate_random_indices(len(data),
    samples)`` then chooses which records are returned, in the order of the
    indices it yields.
    """
    async with aiofiles.open(file_path, mode="r") as file:
        records = [json.loads(line) async for line in file]
    chosen_indices = generate_random_indices(len(records), samples)
    return [records[index] for index in chosen_indices]
|
||||
|
||||
async def evaluate_problem(input: str, context_str: str, graph: Callable, expected_output: str):
    """Predict an answer for one HotpotQA question and score it with F1.

    ``graph(input, context_str)`` is awaited and must yield a
    ``(prediction, supporting_sentences)`` pair.  When *graph* is falsy the
    placeholder ``"None"`` is used for both values, so the fallback path is
    scored normally instead of failing to unpack a bare string.

    Any exception is printed and the whole attempt is retried, up to five
    attempts.  After the last failure the sample is reported with
    ``prediction=None``, ``supporting_sentences=None`` and ``score=0``.

    Returns ``(input, prediction, expected_output, supporting_sentences,
    score)``.
    """
    max_retries = 5

    for attempt in range(1, max_retries + 1):
        try:
            if graph:
                prediction, supporting_sentences = await graph(input, context_str)
            else:
                prediction, supporting_sentences = "None", "None"
            _, predicted_bags = answer_to_bags(prediction)
            _, gold_bags = answer_to_bags(expected_output)
            score = np.mean(_align_bags(predicted_bags, gold_bags))
            return input, prediction, expected_output, supporting_sentences, score
        except Exception as e:
            print(f"Error generating prediction: {e}. Retrying... ({attempt}/{max_retries})")

    print("Maximum retries reached. Skipping this sample.")
    return input, None, expected_output, None, 0
|
||||
|
||||
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50):
    """Evaluate all HotpotQA problems with bounded concurrency.

    For each problem the context string is built from the list-valued second
    element of every ``problem["context"]`` item: the sentences of a
    paragraph are joined with spaces and the paragraphs with newlines.
    Returns what ``tqdm_asyncio.gather`` returns for the per-problem tasks.
    """
    limiter = asyncio.Semaphore(max_concurrent_tasks)

    async def sem_evaluate(problem):
        async with limiter:
            question = problem["question"]
            gold = problem["answer"]
            context_str = "\n".join(
                " ".join(item[1]) for item in problem["context"] if isinstance(item[1], list)
            )
            return await evaluate_problem(question, context_str, graph, gold)

    pending = [sem_evaluate(problem) for problem in data]
    return await tqdm_asyncio.gather(*pending, desc="Evaluating HotpotQA problems", total=len(data))
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, str, float]], path: str) -> float:
    """Write HotpotQA evaluation rows to a CSV named after their mean score.

    Rows are ``(question, prediction, expected_output, supporting_sentences,
    score)``.  The file goes to ``{path}/{mean:.5f}.csv`` without an index
    column; the location is printed and the mean score is returned.
    """
    column_names = ["question", "prediction", "expected_output", "supporting_sentences", "score"]
    frame = pd.DataFrame(results, columns=column_names)
    average_score = frame["score"].mean()

    output_file = f"{path}/{average_score:.5f}.csv"
    frame.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return average_score
|
||||
|
||||
async def hotpotqa_evaluation(graph: Callable, file_path: str, samples: int, path: str) -> float:
    """Evaluate *graph* on a sample of the HotpotQA file and save the results.

    Loads *samples* records from *file_path*, evaluates them with at most 20
    concurrent tasks, writes the CSV under *path*, prints the mean score and
    returns it.
    """
    problems = await load_data(file_path, samples)
    rows = await evaluate_all_problems(problems, graph, max_concurrent_tasks=20)
    average_score = save_results_to_csv(rows, path=path)
    print(f"Average score on HotpotQA dataset: {average_score:.5f}")
    return average_score
|
||||
|
|
|
|||
|
|
@ -1,171 +1,112 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 7/7/2024 17:07 PM
|
||||
# @Author : didi
|
||||
# @Desc : test on human eval graph
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Literal, Optional
|
||||
|
||||
import asyncio
|
||||
import aiofiles
|
||||
from evalplus.data import get_human_eval_plus
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Callable
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from examples.ags.scripts.graph import HumanEvalGraph
|
||||
from examples.ags.scripts.operator import GenerateCodeBlock
|
||||
from examples.ags.scripts.utils import sort_json_by_key
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import add_jsonl_file, read_json_file
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
generate_code_block = GenerateCodeBlock(llm=LLM())
|
||||
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria="correctness, efficiency, readability", vote_count=1)
|
||||
PASS = "pass"
|
||||
FAIL = "fail"
|
||||
|
||||
ModeType = Literal["ags", "alpha_codium", "llm"]
|
||||
async def load_data(file_path: str, samples=1) -> List[dict]:
    """Read a JSON-lines file and return a sampled subset of its records.

    Every line is parsed as JSON; ``generate_random_indices(len(data),
    samples)`` then chooses which records are returned, in the order of the
    indices it yields.
    """
    async with aiofiles.open(file_path, mode="r") as file:
        records = [json.loads(line) async for line in file]
    chosen_indices = generate_random_indices(len(records), samples)
    return [records[index] for index in chosen_indices]
|
||||
|
||||
async def check_solution(solution, test_cases, entry_point):
|
||||
# Define a local dictionary to execute the solution
|
||||
local_dict = {}
|
||||
exec("from typing import List\n\n" + solution, {}, local_dict)
|
||||
|
||||
async def llm_generate(id):
|
||||
case = get_human_eval_plus()[f"{id}"]
|
||||
solution_result = await generate_code_block(case["prompt"], case["entry_point"])
|
||||
sample_dict = dict(task_id=case["task_id"], solution=solution_result["code_solution"])
|
||||
return sample_dict
|
||||
# Ensure the entry point function is defined
|
||||
if entry_point not in local_dict:
|
||||
raise ValueError(f"Function {entry_point} is not defined in the solution.")
|
||||
|
||||
details = [False for _ in range(len(test_cases))]
|
||||
|
||||
async def ags_generate(id, ensemble_count: int = 5):
|
||||
case = get_human_eval_plus()[f"{id}"]
|
||||
solution_result = await solver(case["prompt"], case["entry_point"], ensemble_count=ensemble_count)
|
||||
sample_dict = dict(task_id=case["task_id"], solution=solution_result["final_solution"])
|
||||
return sample_dict
|
||||
# Check each test case
|
||||
for i, test in enumerate(test_cases):
|
||||
# Replace 'candidate' with the actual function call
|
||||
test_expr = test.replace("candidate", entry_point)
|
||||
try:
|
||||
# Evaluate the test case
|
||||
if eval(test_expr, {}, local_dict):
|
||||
details[i] = True
|
||||
except Exception as e:
|
||||
print(f"Error evaluating test case '{test}': {e}")
|
||||
|
||||
if all(details):
|
||||
return PASS, details
|
||||
|
||||
async def alpha_codium_generate(id, ensemble_count: int = 1):
|
||||
case = get_human_eval_plus()[f"{id}"]
|
||||
solution_result = await solver.alpha_codium(case["task_id"], case["prompt"], ensemble_count=ensemble_count)
|
||||
sample_dict = dict(task_id=case["task_id"], solution=solution_result["final_solution"])
|
||||
return sample_dict
|
||||
return FAIL, details
|
||||
|
||||
async def evaluate_problem(data: dict, graph: Callable) -> Tuple[str, str, str, int]:
|
||||
max_retries = 5
|
||||
retries = 0
|
||||
|
||||
async def route_generate(mode: ModeType, id: str):
|
||||
token_usage = 0
|
||||
money_usage = 0
|
||||
if mode == "ags":
|
||||
sample_dict = await ags_generate(id)
|
||||
elif mode == "alpha_codium":
|
||||
sample_dict = await alpha_codium_generate(id, 5)
|
||||
elif mode == "llm":
|
||||
sample_dict = await llm_generate(id)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
return sample_dict, token_usage, money_usage
|
||||
while retries < max_retries:
|
||||
try:
|
||||
solution = await graph(data["prompt"]) if graph else "None"
|
||||
ret = await check_solution(solution, data["test_cases"], data["entry_point"])
|
||||
|
||||
score = 1 if ret[0] == PASS else 0
|
||||
break
|
||||
|
||||
async def sample_generate(id, result_path: str = "samples.jsonl", mode: ModeType = "ags"):
|
||||
sample_dict, token_usage, money_usage = await route_generate(mode, id)
|
||||
add_jsonl_file(result_path, [sample_dict])
|
||||
sort_json_by_key(result_path, result_path)
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
print(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
|
||||
|
||||
if retries == max_retries:
|
||||
print("Maximum retries reached. Skipping this sample.")
|
||||
solution = None
|
||||
ret = (FAIL, [])
|
||||
score = 0
|
||||
break
|
||||
|
||||
async def samples_generate(mode: ModeType, result_path: str = "samples.jsonl", max_concurrency: int = 50):
|
||||
ids = list(get_human_eval_plus().keys())
|
||||
file_lock = asyncio.Lock()
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
return data["prompt"], solution, ret[1], score
|
||||
|
||||
async def solve_and_write(id: str, mode: ModeType) -> Optional[str]:
|
||||
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int]]:
|
||||
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||
|
||||
async def sem_evaluate(problem):
|
||||
async with semaphore:
|
||||
try:
|
||||
sample_dict, token_usage, money_usage = await route_generate(mode, id)
|
||||
except Exception:
|
||||
return id
|
||||
async with file_lock:
|
||||
async with aiofiles.open(result_path, mode="a") as f:
|
||||
await f.write(json.dumps(sample_dict) + "\n")
|
||||
return None
|
||||
return await evaluate_problem(problem, graph)
|
||||
|
||||
tasks = [solve_and_write(id, mode) for id in ids]
|
||||
results = await asyncio.gather(*tasks)
|
||||
failed_tasks = [task_id for task_id in results if task_id is not None]
|
||||
tasks = [sem_evaluate(problem) for problem in data]
|
||||
|
||||
if failed_tasks:
|
||||
logger.info(failed_tasks)
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Evaluating HumanEval problems", total=len(data))
|
||||
|
||||
async def retry_failed_task(task_id):
|
||||
try:
|
||||
await sample_generate(task_id, result_path, mode)
|
||||
return None
|
||||
except Exception:
|
||||
logger.error(f"{task_id} fail")
|
||||
return task_id
|
||||
def save_results_to_jsonl(results: List[Tuple[str, str, str, int]], path: str) -> float:
|
||||
avg_score = 0
|
||||
|
||||
retry_results = await asyncio.gather(*[retry_failed_task(task_id) for task_id in failed_tasks])
|
||||
failed_tasks = [task_id for task_id in retry_results if task_id is not None]
|
||||
with open(path, "w") as f:
|
||||
for result in results:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"question": result[0],
|
||||
"prediction": result[1],
|
||||
"test_case_details": result[2],
|
||||
"score": result[3],
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
avg_score += result[3]
|
||||
print(f"Results saved to {path}")
|
||||
avg_score /= len(results)
|
||||
|
||||
sort_json_by_key(result_path, result_path)
|
||||
return avg_score
|
||||
|
||||
if not failed_tasks:
|
||||
if automatic_evalplus(result_path):
|
||||
eval_path = result_path[:-6] + "_eval_results.json"
|
||||
unpassed_exapmle = extract_failure_tests(eval_path)
|
||||
logger.info(unpassed_exapmle)
|
||||
else:
|
||||
logger.info(failed_tasks)
|
||||
|
||||
|
||||
@handle_exception(exception_type=subprocess.CalledProcessError, exception_msg="sanitize error", default_return=None)
def automatic_sanitize(result_path: str = "samples.jsonl") -> Optional[str]:
    """Run ``evalplus.sanitize --samples result_path`` as a subprocess.

    Returns *result_path* with its extension replaced by
    ``-sanitized.jsonl``.  A non-zero exit status raises
    ``subprocess.CalledProcessError`` (``check=True``), which is the
    exception type handed to the ``handle_exception`` decorator.
    """
    subprocess.run(["evalplus.sanitize", "--samples", result_path], check=True)

    stem, _extension = os.path.splitext(result_path)
    return f"{stem}-sanitized.jsonl"
|
||||
|
||||
|
||||
@handle_exception(
    exception_type=subprocess.CalledProcessError,
    exception_msg="Error in automatic_evalplus function",
    default_return=False,
)
def automatic_evalplus(result_path: str = "samples.jsonl") -> bool:
    """Run the EvalPlus HumanEval evaluation on *result_path* as a subprocess.

    Equivalent to ``python -m evalplus.evaluate --dataset humaneval
    --samples <result_path> --parallel 2 --base-only``.  The captured stdout
    is logged and True is returned; a non-zero exit status raises
    ``subprocess.CalledProcessError`` (``check=True``), which is the
    exception type handed to the ``handle_exception`` decorator.
    """
    # sys.executable keeps the run on the current Python interpreter.
    command = [sys.executable, "-m", "evalplus.evaluate"]
    command += ["--dataset", "humaneval"]
    command += ["--samples", result_path]
    command += ["--parallel", "2"]
    command += ["--base-only"]

    result = subprocess.run(command, check=True, capture_output=True, text=True)
    logger.info(f"ouptput: \n {result.stdout}")
    return True
|
||||
|
||||
|
||||
def extract_failure_tests(file_path: str = "samples_eval_results.json"):
    """List the tasks whose base status is "fail" in an eval-results file.

    The file is loaded with ``read_json_file``.  For every value under its
    "eval" key, the first entry is inspected; when its "base_status" equals
    "fail" a ``{"task_id": ...}`` dict is collected.  The number of failures
    is logged and the list is returned.
    """
    task_results = read_json_file(file_path)

    failed_tests = [
        {"task_id": entries[0]["task_id"]}
        for entries in task_results["eval"].values()
        if entries[0]["base_status"] == "fail"
    ]
    logger.info(f"length of failed tests: {len(failed_tests)}")

    return failed_tests
|
||||
async def humaneval_evaluation(graph: Callable, file_path: str, samples: int, path: str) -> float:
    """Evaluate *graph* on a sample of the HumanEval file and save the results.

    Loads *samples* records from *file_path*, evaluates them with at most 20
    concurrent tasks, writes the JSON-lines results to *path*, prints the
    mean score and returns it.
    """
    problems = await load_data(file_path, samples)
    rows = await evaluate_all_problems(problems, graph, max_concurrent_tasks=20)
    average_score = save_results_to_jsonl(rows, path=path)
    print(f"Average score on HumanEval dataset: {average_score:.5f}")
    return average_score
|
||||
|
|
|
|||
277
examples/ags/benchmark/math.py
Normal file
277
examples/ags/benchmark/math.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
import re
|
||||
import regex
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from math import isclose
|
||||
import multiprocessing
|
||||
import json
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Tuple, Callable, Union
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
def extract_answer(text: str) -> str:
    r"""Extract the final answer from a model response.

    If the text contains ``\boxed{...}``, the content of the first such group
    is returned.  Braces are matched by depth, so nested groups such as
    ``\boxed{\frac{1}{2}}`` come back whole instead of being cut at the
    first closing brace.

    Otherwise (or when the boxed group is never closed) the last non-empty
    "."-separated segment of the text is returned, stripped; a trailing
    period therefore no longer yields an empty answer.  Returns ``""`` when
    there is no non-blank segment.
    """
    marker = "\\boxed{"
    start = text.find(marker)
    if start != -1:
        content_start = start + len(marker)
        depth = 1
        for position in range(content_start, len(text)):
            char = text[position]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[content_start:position]
        # Unbalanced braces: fall through to the sentence heuristic.

    for sentence in reversed(text.split(".")):
        sentence = sentence.strip()
        if sentence:
            return sentence
    return ""
|
||||
|
||||
def parse_digits(num):
    """Parse *num* as a float, accepting thousands separators and percents.

    Commas are removed from ``str(num)`` first.  If the remainder parses as
    a float, that value is returned.  Otherwise, if it ends with ``%``
    (optionally preceded by a backslash, as in LaTeX ``\\%``), the sign is
    stripped and the value divided by 100 is returned.  Returns ``None`` when
    neither form parses.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if text.endswith("%"):
        text = text[:-1]
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
|
||||
|
||||
def is_digit(num):
    """Return True when ``parse_digits`` can turn *num* into a float."""
    parsed = parse_digits(num)
    return parsed is not None
|
||||
|
||||
def symbolic_equal(a, b):
    """Return True when *a* and *b* are equal symbolically or numerically.

    Each argument is parsed with ``parse_latex`` and, failing that,
    ``parse_expr``; if both parsers fail the argument is used unchanged.
    Equality holds when ``simplify(a - b)`` is zero, or when the numeric
    values ``N(a)`` and ``N(b)`` are within an absolute tolerance of 1e-3.
    Any failure along the way counts as "not equal by this method".
    """

    def _parse(s):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except Exception:
                # Try the next parser; `except Exception` (not a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                continue
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
|
||||
|
||||
def call_with_timeout(func, *args, timeout=5, **kwargs):
    """Run *func* in a child process and return its result, or False on timeout.

    *func* is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on the queue.

    The queue is read *before* the process is joined: joining first can
    deadlock while the child is still flushing a large item, and an
    unconditional ``get()`` would block forever if the child exited without
    producing a result.  If nothing arrives within *timeout* seconds the
    child is terminated and ``False`` is returned.
    """
    import queue

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=func, args=args + (output_queue,), kwargs=kwargs)
    process.start()

    try:
        result = output_queue.get(timeout=timeout)
    except queue.Empty:
        # Timed out, or the child died without reporting a result.
        result = False
    else:
        # Give the child a chance to exit on its own after reporting.
        process.join(timeout)

    if process.is_alive():
        process.terminate()
    process.join()
    return result
|
||||
|
||||
def _symbolic_equal_worker(a, b, output_queue):
    """Child-process entry point: put the result of ``symbolic_equal(a, b)`` on the queue."""
    output_queue.put(symbolic_equal(a, b))


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether a predicted answer matches the reference answer.

    Checks run in this order, and the first success returns ``True``:

    0. The string forms are identical.
    1. Numerical equality: both values parse as numbers and are equal. With
       ``include_percentage`` the reference is also tried divided and
       multiplied by 100. If both are numeric but unequal, the result is
       ``False`` without trying the symbolic checks.
    2. Symbolic equality: element-wise for tuples/intervals and
       ``pmatrix``/``bmatrix`` matrices, rearranged for single-``=``
       equations, and finally via ``symbolic_equal``.

    Args:
        prediction: Answer produced by the model.
        reference: Ground-truth answer.
        include_percentage: Also accept the reference scaled by 1/100 or 100.
        is_close: Compare numbers with an absolute tolerance of 1e-3 instead
            of exact equality.
        timeout: Run the final ``symbolic_equal`` call in a child process via
            ``call_with_timeout``. The flag is also passed to recursive
            element-wise comparisons.

    Returns:
        ``True`` when any check succeeds, otherwise ``False``.
    """
    if str(prediction) == str(reference):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    # An empty prediction can never match, but 0 and False are real answers.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / intervals: compare element by element.
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            math_equal(pred_part, ref_part, include_percentage, is_close, timeout)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    # Matrices: compare row by row, cell by cell.
    if (
        (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}"))
        and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}"))
        and (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}"))
        and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}"))
    ):
        # "pmatrix" and "bmatrix" have the same length, so one slice fits both.
        pred_lines = [
            line.strip()
            for line in prediction[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\")
            if line.strip()
        ]
        ref_lines = [
            line.strip()
            for line in reference[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\")
            if line.strip()
        ]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not all(
                    math_equal(pred_part, ref_part, include_percentage, is_close, timeout)
                    for pred_part, ref_part in zip(pred_parts, ref_parts)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Equations: compare "lhs - (rhs)" of both sides, in either sign.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif prediction.count("=") == 1 and len(prediction.split("=")[0].strip()) <= 2 and "=" not in reference:
        # "x = 5" versus "5": compare only the right-hand side.
        if math_equal(prediction.split("=")[1], reference, include_percentage, is_close, timeout):
            return True
    elif reference.count("=") == 1 and len(reference.split("=")[0].strip()) <= 2 and "=" not in prediction:
        if math_equal(prediction, reference.split("=")[1], include_percentage, is_close, timeout):
            return True

    # symbolic equal with sympy
    if timeout:
        # call_with_timeout appends a result queue to the arguments, so the
        # target must be the worker wrapper, not symbolic_equal itself.
        if call_with_timeout(_symbolic_equal_worker, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
|
||||
|
||||
def calculate_score(expected_output: str, prediction: str) -> int:
    """Score one sample.

    Answers are extracted from both texts with ``extract_answer`` and compared
    with ``math_equal``.

    Returns:
        1 when the answers are judged equal, otherwise 0.
    """
    reference_answer = extract_answer(expected_output)
    candidate_answer = extract_answer(prediction)

    if math_equal(candidate_answer, reference_answer):
        return 1
    return 0
|
||||
|
||||
async def load_data(file_path: str, samples: int = 200) -> List[dict]:
    """Read a JSON-lines file and return the subset chosen by ``generate_random_indices``.

    Args:
        file_path: Path to a file with one JSON document per line.
        samples: Passed to ``generate_random_indices`` with the record count.

    Returns:
        The records at the returned indices, in the order of those indices.
    """
    async with aiofiles.open(file_path, mode="r") as handle:
        records = [json.loads(raw_line) async for raw_line in handle]

    chosen_indices = generate_random_indices(len(records), samples)
    return [records[index] for index in chosen_indices]
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, int, str]], path: str) -> float:
    """Write per-problem results to a CSV file named after the mean score.

    Args:
        results: Rows of (question, prediction, expected_output, score, cost).
        path: Directory that receives ``<average_score>.csv``, with the score
            formatted to five decimals.

    Returns:
        The mean of the ``score`` column.
    """
    column_names = ["question", "prediction", "expected_output", "score", "cost"]
    frame = pd.DataFrame(results, columns=column_names)
    average_score = frame["score"].mean()

    output_file = f"{path}/{average_score:.5f}.csv"
    frame.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")
    return average_score
|
||||
|
||||
async def evaluate_problem(problem: dict, graph: Callable) -> Tuple[str, str, str, int, str]:
    """Run ``graph`` on one problem and score its answer, with up to five attempts.

    ``graph`` is awaited with the problem text. Element 0 of its result is
    read as a mapping with a ``"solution"`` key and element 1 as the cost.
    Any exception triggers a retry. After the fifth failure the sample is
    recorded with no output, no cost and a score of 0.

    Returns:
        (question, model output, expected solution, score, cost)
    """
    question = problem["problem"]
    reference = problem["solution"]
    max_retries = 5

    for attempt in range(1, max_retries + 1):
        try:
            result = await graph(question)
            cost = result[1]
            output = result[0]["solution"]
            score = calculate_score(reference, output)
            break
        except Exception as e:
            print(f"Error generating prediction: {e}. Retrying... ({attempt}/{max_retries})")

            if attempt == max_retries:
                print("Maximum retries reached. Skipping this sample.")
                output = None
                cost = None
                score = 0

    return question, output, reference, score, cost
|
||||
|
||||
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 20) -> List[Tuple[str, str, str, int, str]]:
    """Evaluate every problem concurrently, with a cap on simultaneous evaluations.

    Args:
        data: Problems to evaluate.
        graph: Callable passed to ``evaluate_problem`` for each problem.
        max_concurrent_tasks: Maximum number of evaluations running at once.

    Returns:
        The result of ``tqdm_asyncio.gather`` over all evaluations.
    """
    limiter = asyncio.Semaphore(max_concurrent_tasks)

    async def bounded_evaluate(item):
        # Hold a semaphore slot for the whole evaluation.
        async with limiter:
            return await evaluate_problem(item, graph)

    pending = [bounded_evaluate(item) for item in data]
    return await tqdm_asyncio.gather(*pending, desc="Evaluating MATH problems", total=len(data))
|
||||
|
||||
async def math_evaluation(graph: Callable, file_path: str, samples: int, path: str) -> float:
    """Run the MATH benchmark: load the data, evaluate it, save the results.

    Args:
        graph: Callable evaluated on each problem.
        file_path: JSON-lines dataset file.
        samples: Sample count passed to ``load_data``.
        path: Directory for the results CSV.

    Returns:
        The average score over the evaluated problems.
    """
    problems = await load_data(file_path, samples)
    rows = await evaluate_all_problems(problems, graph, max_concurrent_tasks=20)
    average_score = save_results_to_csv(rows, path=path)
    print(f"Average score on MATH dataset: {average_score:.5f}")
    return average_score
|
||||
105
examples/ags/benchmark/mbpp.py
Normal file
105
examples/ags/benchmark/mbpp.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
import json
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Callable
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
# Verdict labels returned by check_solution as the first element of its result tuple.
PASS = "pass"
FAIL = "fail"
|
||||
|
||||
async def load_data(file_path: str, samples=1) -> List[dict]:
    """Read a JSON-lines file and return the subset chosen by ``generate_random_indices``.

    Args:
        file_path: Path to a file with one JSON document per line.
        samples: Passed to ``generate_random_indices`` with the record count.

    Returns:
        The records at the returned indices, in the order of those indices.
    """
    async with aiofiles.open(file_path, mode="r") as handle:
        records = [json.loads(raw_line) async for raw_line in handle]

    chosen_indices = generate_random_indices(len(records), samples)
    return [records[index] for index in chosen_indices]
|
||||
|
||||
def _run_test_case(test, namespace):
    """Run one test snippet in ``namespace`` and return whether it holds.

    An ``assert`` statement is executed as written, so a failing test raises
    ``AssertionError``. Anything else is evaluated as an expression and its
    truthiness is returned.
    """
    snippet = test.strip()
    if snippet.startswith(("assert ", "assert(")):
        # optimize=0 keeps the assert even when the interpreter runs with -O.
        exec(compile(snippet, "<test_case>", "exec", optimize=0), namespace)
        return True
    return bool(eval(snippet, namespace))


async def check_solution(solution, test_cases, timeout=1):
    """Execute a candidate solution and run its test cases in order.

    The solution and the tests share a single namespace, so functions defined
    by the solution can call each other, recurse and use modules the solution
    imports. Checking stops at the first failing test.

    NOTE(security): ``solution`` and the tests are run with ``exec``/``eval``
    in this process. Only use this with trusted input or inside a sandbox.

    Args:
        solution: Python source code of the candidate solution.
        test_cases: Test snippets, typically ``assert ...`` statements.
        timeout: Seconds allowed per test case. A timed-out test counts as
            failed, but its worker thread cannot be interrupted and may keep
            running in the background.

    Returns:
        ``(PASS, details)`` when every test holds, otherwise
        ``(FAIL, details)``. ``details`` holds one boolean per test case;
        tests after the first failure stay ``False``.

    Raises:
        Exception: Whatever executing ``solution`` itself raises.
    """
    namespace = {}
    exec(solution, namespace)

    details = [False] * len(test_cases)

    async def evaluate_test(test):
        try:
            # Run in a worker thread so the timeout can be enforced.
            return await asyncio.wait_for(asyncio.to_thread(_run_test_case, test, namespace), timeout)
        except asyncio.TimeoutError:
            print(f"Test case '{test}' timed out.")
        except AssertionError:
            print(f"Test case '{test}' failed.")
        except Exception as e:
            print(f"Error evaluating test case '{test}': {e}")
        return False

    # Check each test case, stopping at the first failure.
    for i, test in enumerate(test_cases):
        details[i] = await evaluate_test(test)
        if not details[i]:
            return FAIL, details

    return PASS, details
|
||||
|
||||
async def evaluate_problem(data: dict, graph: Callable) -> Tuple[str, str, str, int]:
    """Generate a solution for one problem and test it, with up to five attempts.

    When ``graph`` is truthy it is awaited with ``data["prompt"]``; otherwise
    the literal string ``"None"`` is used as the solution. The solution is
    checked against ``data["test_list"]`` with ``check_solution``. Any
    exception triggers a retry. After the fifth failure the sample is
    recorded with no solution, empty details and a score of 0.

    Returns:
        (prompt, solution, per-test details, score)
    """
    max_retries = 5

    for attempt in range(1, max_retries + 1):
        try:
            if graph:
                solution = await graph(data["prompt"])
            else:
                solution = "None"
            ret = await check_solution(solution, data["test_list"])
            score = 1 if ret[0] == PASS else 0
            break
        except Exception as e:
            print(f"Error generating prediction: {e}. Retrying... ({attempt}/{max_retries})")

            if attempt == max_retries:
                print("Maximum retries reached. Skipping this sample.")
                solution = None
                ret = (FAIL, [])
                score = 0

    return data["prompt"], solution, ret[1], score
|
||||
|
||||
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int]]:
    """Evaluate every problem concurrently, with a cap on simultaneous evaluations.

    Args:
        data: Problems to evaluate.
        graph: Callable passed to ``evaluate_problem`` for each problem.
        max_concurrent_tasks: Maximum number of evaluations running at once.

    Returns:
        The result of ``tqdm_asyncio.gather`` over all evaluations.
    """
    limiter = asyncio.Semaphore(max_concurrent_tasks)

    async def bounded_evaluate(item):
        # Hold a semaphore slot for the whole evaluation.
        async with limiter:
            return await evaluate_problem(item, graph)

    pending = [bounded_evaluate(item) for item in data]
    return await tqdm_asyncio.gather(*pending, desc="Evaluating MBPP problems", total=len(data))
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, int]], path: str) -> float:
    """Write per-problem results to a CSV file named after the mean score.

    Args:
        results: Rows of (question, prediction, test_case_details, score).
        path: Directory that receives ``<average_score>.csv``, with the score
            formatted to five decimals.

    Returns:
        The mean of the ``score`` column.
    """
    column_names = ["question", "prediction", "test_case_details", "score"]
    frame = pd.DataFrame(results, columns=column_names)
    average_score = frame["score"].mean()

    output_file = f"{path}/{average_score:.5f}.csv"
    frame.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return average_score
|
||||
|
||||
async def mbpp_evaluation(graph: Callable, file_path: str, samples: int, path: str) -> float:
    """Run the MBPP benchmark: load the data, evaluate it, save the results.

    Args:
        graph: Callable evaluated on each problem.
        file_path: JSON-lines dataset file.
        samples: Sample count passed to ``load_data``.
        path: Directory for the results CSV.

    Returns:
        The average score over the evaluated problems.
    """
    problems = await load_data(file_path, samples)
    rows = await evaluate_all_problems(problems, graph, max_concurrent_tasks=20)
    average_score = save_results_to_csv(rows, path=path)
    print(f"Average score on MBPP dataset: {average_score:.5f}")
    return average_score
|
||||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue