Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
kit 2024-10-09 21:26:59 +08:00
parent f2326c97f8
commit 2a19c174d9
5 changed files with 136 additions and 22 deletions

View file

@ -1,18 +1,51 @@
import json
from pathlib import Path
from metagpt.const import DABENCH_PATH
from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
def evaluate_accuracy_by_question(results):
    """Return the fraction of questions whose sub-questions are all correct.

    Args:
        results: Iterable-with-length of dicts. An entry counts as correct
            only when it has a ``"correctness"`` mapping and every value in
            that mapping is truthy. Entries without ``"correctness"`` still
            count toward the total.

    Returns:
        The accuracy rounded to 4 decimal places, or 0 when ``results`` is
        empty.
    """
    correct = sum(
        "correctness" in result and all(result["correctness"].values())
        for result in results
    )
    total = len(results)
    return round(correct / total, 4) if total > 0 else 0
def evaluate_accuracy_by_sub_question(results):
    """Return the fraction of correct sub-questions across all questions.

    Args:
        results: Iterable of dicts. Only entries that carry a
            ``"correctness"`` mapping contribute; all others are ignored for
            both the numerator and the denominator.

    Returns:
        The accuracy rounded to 4 decimal places, or 0 when no sub-question
        was found.
    """
    scored = [result["correctness"] for result in results if "correctness" in result]
    correct = sum(sum(correctness.values()) for correctness in scored)
    total = sum(len(correctness) for correctness in scored)
    return round(correct / total, 4) if total > 0 else 0
def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
    """Return the mean per-question score, with partial credit per sub-question.

    Each question is worth 1 point, split evenly across its sub-questions.
    Entries without a ``"correctness"`` mapping (or with an empty one) score 0
    but still count toward the denominator.

    Args:
        results: Sequence of dicts, optionally carrying a ``"correctness"``
            mapping of sub-question name to a truthy/falsy flag.

    Returns:
        The mean score rounded to 4 decimal places, or 0 when ``results`` is
        empty.
    """
    total_score = 0
    for result in results:
        if "correctness" not in result:
            continue
        correctness = result["correctness"]
        sub_question_count = len(correctness)
        # Guard against an empty mapping so the division cannot fail.
        score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
        total_score += sum(correctness.values()) * score_per_sub_question
    return round(total_score / len(results), 4) if results else 0
class DABench:
def __init__(
    self,
    questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
    answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
    template="",
):
    """Load the benchmark questions and labels from JSONL files.

    Args:
        questions_file: Path to a JSONL file; every non-blank line must be a
            JSON object with an ``"id"`` field convertible to ``int``.
        answers_file: Path to a JSONL file with the same per-line layout.
        template: Prompt template; when empty (falsy) the ``DABENCH``
            template is used instead.
    """

    def load_by_id(path):
        # Decode each line once and key the record by its integer id.
        # UTF-8 is requested explicitly so loading does not depend on the
        # platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as file:
            records = (json.loads(line) for line in file if line.strip())
            return {int(record["id"]): record for record in records}

    # Read questions from a JSONL file
    self.questions = load_by_id(questions_file)
    # Read answers from a JSONL file
    self.answers = load_by_id(answers_file)
    self.template = template if template else DABENCH
def get_question(self, question_id):
    """Retrieve the question record by its id.

    Args:
        question_id: Key into ``self.questions``.

    Returns:
        The stored question record, or the string ``"Question not found."``
        when the id is unknown. Note that no exception is raised, so callers
        that index into the result must check for the string first.
    """
    return self.questions.get(question_id, "Question not found.")
@ -20,7 +53,13 @@ class DABench:
def get_prompt(self, question_id):
    """Build the task prompt for a question by filling in the template.

    Args:
        question_id: Id of the question to render.

    Returns:
        ``self.template`` formatted with the question's ``question``,
        ``constraints``, ``format`` and ``level`` fields, plus the path of
        its data file under ``<DABENCH_PATH>/da-dev-tables/``.

    Raises:
        KeyError: If no question with ``question_id`` exists.
    """
    temp = self.get_question(question_id)
    if not isinstance(temp, dict):
        # get_question returns a sentinel string for unknown ids; indexing it
        # below would fail with an unhelpful TypeError, so fail clearly here.
        raise KeyError(f"No question with id {question_id!r}")
    return self.template.format(
        question=temp["question"],
        constraints=temp["constraints"],
        format=temp["format"],
        file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
        level=temp["level"],
    )
def get_answer(self, answer_id):
"""Retrieve the answer list by its id."""
@ -28,13 +67,13 @@ class DABench:
def eval(self, id, prediction):
"""Evaluate the prediction against the true label."""
true_label = self.get_answer(id)['common_answers']
true_label = self.get_answer(id)["common_answers"]
# Parse the prediction string into a dictionary of metric-value pairs
pred_dict = {}
for pred in prediction.split(','):
parts = pred.strip().split('[')
metric = parts[0].strip().replace('@', '')
value = float(parts[1].rstrip(']'))
for pred in prediction.split(","):
parts = pred.strip().split("[")
metric = parts[0].strip().replace("@", "")
value = float(parts[1].rstrip("]"))
pred_dict[metric] = value
# Sort the true labels to match the order of predictions
@ -49,9 +88,56 @@ class DABench:
return correct
def eval_all(self, id_list, predictions):
    """Evaluate a batch of predictions and calculate three accuracy rates.

    Args:
        id_list: Question ids, paired positionally with ``predictions``
            (pairing stops at the shorter of the two).
        predictions: Prediction strings of the form
            ``"@metric_a[1.0], @metric_b[2.5]"``.

    Returns:
        A dict with the keys ``"accuracy_by_question"``,
        ``"accuracy_by_sub_question"`` and
        ``"proportional_accuracy_by_sub_question"``.
    """

    def parse_prediction(prediction):
        """Parse a prediction string into a dict of metric name -> float."""
        parsed = {}
        for fragment in prediction.split(","):
            parts = fragment.strip().split("[")
            if len(parts) < 2:
                # No "[value]" part. Predictions come from model output, so a
                # malformed fragment is skipped (its metric stays incorrect)
                # instead of aborting the whole batch.
                continue
            try:
                value = float(parts[1].rstrip("]"))
            except ValueError:
                continue
            parsed[parts[0].strip().replace("@", "")] = value
        return parsed

    def single_eval(question_id, prediction):
        """Return a dict mapping each labelled metric to whether it was predicted correctly."""
        true_label = self.get_answer(question_id)["common_answers"]
        pred_dict = parse_prediction(prediction)
        # Start from "all wrong" so metrics missing from the prediction count as incorrect.
        correctness = {metric: False for metric, _ in true_label}
        for metric, true_value in true_label:
            # Consider the prediction correct if it's within a small tolerance.
            # The label conversion is deliberately left unguarded: a bad label
            # is a dataset problem and should surface.
            if metric in pred_dict and abs(pred_dict[metric] - float(true_value)) < 1e-6:
                correctness[metric] = True
        return correctness

    results = [
        {"id": question_id, "correctness": single_eval(question_id, prediction)}
        for question_id, prediction in zip(id_list, predictions)
    ]

    # Calculate the three accuracy rates
    return {
        "accuracy_by_question": evaluate_accuracy_by_question(results),
        "accuracy_by_sub_question": evaluate_accuracy_by_sub_question(results),
        "proportional_accuracy_by_sub_question": evaluate_accuracy_proportional_by_sub_question_adjusted(results),
    }
if __name__ == "__main__":
    # Small manual demo: score one prediction, then a batch of three.
    DA = DABench()

    question_id = 6
    prediction = "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
    is_correct = DA.eval(question_id, prediction)
    print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")

    # Separate names for the batch so the builtin `id` is not shadowed and a
    # single name is not rebound from an int to a list.
    question_ids = [0, 5, 6]
    predictions = [
        "@mean_fare[34.89]",
        "@correlation_coefficient[0.21]",
        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
    ]
    print(DA.eval_all(question_ids, predictions))

View file

@ -7,5 +7,6 @@ ## Dataset-install
```
## How to run
```
python run_InfiAgent-DABench.py --id x
python run_InfiAgent-DABench_sigle.py --id x  # Run a single task; x is the question id
python run_InfiAgent-DABench_all.py # Run all tasks
```

View file

@ -0,0 +1,24 @@
import json
import fire
from DABench import DABench
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
    """Run the Data Interpreter on every DABench task and print the accuracy rates."""
    DA = DABench()
    id_list, predictions = [], []
    # Only the ids are needed here, so iterate the keys rather than items().
    for key in DA.answers:
        requirement = DA.get_prompt(key)
        di = DataInterpreter()
        result = await di.run(requirement)
        # Take the text between "Current Plan" and "## Current Task" in the
        # stringified result, parse it as JSON, and use the "result" field of
        # its last element as the prediction.
        prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
        id_list.append(key)
        predictions.append(prediction)
    print(DA.eval_all(id_list, predictions))


if __name__ == "__main__":
    fire.Fire(main)

View file

@ -1,17 +1,20 @@
import json
import fire
from metagpt.roles.di.data_interpreter import DataInterpreter
import fire
from DABench import DABench
async def main(id=0):
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(id=5):
    """Run the Data Interpreter on one DABench task and print whether it was correct.

    Args:
        id: Id of the question to run (exposed as ``--id`` on the command line).
    """
    DA = DABench()
    requirement = DA.get_prompt(id)
    di = DataInterpreter()
    result = await di.run(requirement)
    # Take the text between "Current Plan" and "## Current Task" in the
    # stringified result, parse it as JSON, and use the "result" field of its
    # last element as the prediction.
    prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
    is_correct = DA.eval(id, prediction)
    print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,4 +1,4 @@
#InfiAgent-DABench requirements
# InfiAgent-DABench requirements
# Placeholders filled via str.format in DABench.get_prompt:
# {question}, {file_name}, {constraints}, {format}, {level}.
DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
# ML-Benchmark requirements