From 2a19c174d940533b855d676b3940203eff089c43 Mon Sep 17 00:00:00 2001
From: kit <101046518@qq.com>
Date: Wed, 9 Oct 2024 21:26:59 +0800
Subject: [PATCH] 1

Signed-off-by: kit <101046518@qq.com>
---
Note (review): the line structure of this patch was restored from a copy whose
newlines and indentation had been collapsed. Indentation and blank context
lines were reconstructed so that every hunk matches its "@@ -a,b +c,d @@"
line counts; the changed content itself is untouched.

 examples/di/InfiAgent-DABench/DABench.py      | 118 +++++++++++++++---
 examples/di/InfiAgent-DABench/README.md       |   3 +-
 .../run_InfiAgent-DABench_all.py              |  24 ++++
 ...ench.py => run_InfiAgent-DABench_sigle.py} |  11 +-
 examples/di/requirements_prompt.py            |   2 +-
 5 files changed, 136 insertions(+), 22 deletions(-)
 create mode 100644 examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
 rename examples/di/InfiAgent-DABench/{run_InfiAgent-DABench.py => run_InfiAgent-DABench_sigle.py} (87%)

diff --git a/examples/di/InfiAgent-DABench/DABench.py b/examples/di/InfiAgent-DABench/DABench.py
index 4994acb78..0e3b7ccb4 100644
--- a/examples/di/InfiAgent-DABench/DABench.py
+++ b/examples/di/InfiAgent-DABench/DABench.py
@@ -1,18 +1,51 @@
 import json
 from pathlib import Path
-from metagpt.const import DABENCH_PATH
+
 from examples.di.requirements_prompt import DABENCH
+from metagpt.const import DABENCH_PATH
+
+
+# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
+def evaluate_accuracy_by_question(results):
+    correct = sum("correctness" in result and all(result["correctness"].values()) for result in results)
+    total = len(results)
+    return round(correct / total, 4) if total > 0 else 0
+
+
+def evaluate_accuracy_by_sub_question(results):
+    correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result)
+    total = sum(len(result["correctness"]) for result in results if "correctness" in result)
+    return round(correct / total, 4) if total > 0 else 0
+
+
+def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
+    total_score = 0
+    for result in results:
+        if "correctness" in result:
+            sub_question_count = len(result["correctness"])
+            score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
+            question_score = sum(result["correctness"].values()) * score_per_sub_question
+            total_score += question_score
+    return round(total_score / len(results), 4) if results else 0
+
+
 class DABench:
-    def __init__(self, questions_file=Path(DABENCH_PATH) / 'da-dev-questions.jsonl', answers_file=Path(DABENCH_PATH) / 'da-dev-labels.jsonl', template = ''):
+    def __init__(
+        self,
+        questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
+        answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
+        template="",
+    ):
         # Read questions from a JSONL file
-        with open(questions_file, 'r') as file:
-            self.questions = {int(json.loads(line)['id']): json.loads(line) for line in file}
+        with open(questions_file, "r") as file:
+            self.questions = {int(json.loads(line)["id"]): json.loads(line) for line in file}
 
         # Read answers from a JSONL file
-        with open(answers_file, 'r') as file:
-            self.answers = {int(json.loads(line)['id']): json.loads(line) for line in file}
+        with open(answers_file, "r") as file:
+            self.answers = {int(json.loads(line)["id"]): json.loads(line) for line in file}
 
         self.template = template if template else DABENCH
+
     def get_question(self, question_id):
         """Retrieve the question by its id."""
         return self.questions.get(question_id, "Question not found.")
@@ -20,7 +53,13 @@ class DABench:
     def get_prompt(self, question_id):
         """Retrieve the question by its id."""
         temp = self.get_question(question_id)
-        return self.template.format(question=temp['question'], constraints=temp['constraints'], format=temp['format'], file_name= str(DABENCH_PATH) + '/da-dev-tables/' + temp['file_name'], level=temp['level'],)
+        return self.template.format(
+            question=temp["question"],
+            constraints=temp["constraints"],
+            format=temp["format"],
+            file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
+            level=temp["level"],
+        )
 
     def get_answer(self, answer_id):
         """Retrieve the answer list by its id."""
@@ -28,13 +67,13 @@ class DABench:
 
     def eval(self, id, prediction):
         """Evaluate the prediction against the true label."""
-        true_label = self.get_answer(id)['common_answers']
+        true_label = self.get_answer(id)["common_answers"]
         # Parse the prediction string into a dictionary of metric-value pairs
         pred_dict = {}
-        for pred in prediction.split(','):
-            parts = pred.strip().split('[')
-            metric = parts[0].strip().replace('@', '')
-            value = float(parts[1].rstrip(']'))
+        for pred in prediction.split(","):
+            parts = pred.strip().split("[")
+            metric = parts[0].strip().replace("@", "")
+            value = float(parts[1].rstrip("]"))
             pred_dict[metric] = value
 
         # Sort the true labels to match the order of predictions
@@ -49,9 +88,56 @@ class DABench:
 
         return correct
 
+    def eval_all(self, id_list, predictions):
+        """Evaluate all predictions and calculate accuracy rates."""
+
+        def sigle_eval(id, prediction):
+            """Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
+            true_label = self.get_answer(id)["common_answers"]
+            pred_dict = {}
+
+            # Parse the prediction string into a dictionary of metric-value pairs
+            for pred in prediction.split(","):
+                parts = pred.strip().split("[")
+                metric = parts[0].strip().replace("@", "")
+                value = float(parts[1].rstrip("]"))
+                pred_dict[metric] = value
+
+            # Initialize the correctness dictionary with False values
+            correctness = {metric: False for metric, _ in true_label}
+
+            # Check each metric's prediction against the true label
+            for metric, true_value in true_label:
+                if metric in pred_dict:
+                    # Consider the prediction correct if it's within a small tolerance
+                    if abs(pred_dict[metric] - float(true_value)) < 1e-6:
+                        correctness[metric] = True
+
+            return correctness
+
+        results = []
+        for id, prediction in zip(id_list, predictions):
+            correct = sigle_eval(id, prediction)
+            results.append({"id": id, "correctness": correct})
+
+        # Calculate the three accuracy rates
+        accuracy_by_question = evaluate_accuracy_by_question(results)
+        accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
+        proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)
+
+        return {
+            "accuracy_by_question": accuracy_by_question,
+            "accuracy_by_sub_question": accuracy_by_sub_question,
+            "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
+        }
+
+
 if __name__ == "__main__":
     DA = DABench()
-    id = 6
-    prediction = "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
-    is_correct = DA.eval(id, prediction)
-    print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
\ No newline at end of file
+    id = [0, 5, 6]
+    prediction = [
+        "@mean_fare[34.89]",
+        "@correlation_coefficient[0.21]",
+        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
+    ]
+    print(DA.eval_all(id, prediction))
diff --git a/examples/di/InfiAgent-DABench/README.md b/examples/di/InfiAgent-DABench/README.md
index 23eb6504f..8263e34f5 100644
--- a/examples/di/InfiAgent-DABench/README.md
+++ b/examples/di/InfiAgent-DABench/README.md
@@ -7,5 +7,6 @@ ## Dataset-install
 ```
 ## How to run
 ```
-python run_InfiAgent-DABench.py --id x
+python run_InfiAgent-DABench_sigle.py --id x # Run a task
+python run_InfiAgent-DABench_all.py # Run all tasks
 ```
\ No newline at end of file
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
new file mode 100644
index 000000000..852761c6a
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
@@ -0,0 +1,24 @@
+import json
+
+import fire
+from DABench import DABench
+
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main():
+    """Evaluate all"""
+    DA = DABench()
+    id_list, predictions = [], []
+    for key, value in DA.answers.items():
+        requirement = DA.get_prompt(key)
+        di = DataInterpreter()
+        result = await di.run(requirement)
+        prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
+        id_list.append(key)
+        predictions.append(prediction)
+    print(DA.eval_all(id_list, predictions))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py
similarity index 87%
rename from examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
rename to examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py
index d68845b91..3dc4ad0e9 100644
--- a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py
@@ -1,17 +1,20 @@
 import json
-import fire
 
-from metagpt.roles.di.data_interpreter import DataInterpreter
+import fire
 from DABench import DABench
 
-async def main(id=0):
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main(id=5):
     DA = DABench()
     requirement = DA.get_prompt(id)
     di = DataInterpreter()
     result = await di.run(requirement)
-    prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]['result']
+    prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
     is_correct = DA.eval(id, prediction)
     print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
 
+
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/examples/di/requirements_prompt.py b/examples/di/requirements_prompt.py
index 4b9950611..1172e1fe5 100644
--- a/examples/di/requirements_prompt.py
+++ b/examples/di/requirements_prompt.py
@@ -1,4 +1,4 @@
-#InfiAgent-DABench requirements
+# InfiAgent-DABench requirements
 DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
 
 # ML-Benchmark requirements