Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
kithib 2024-10-08 23:18:35 +08:00 committed by kit
parent bdba23e422
commit b5981d25a5
5 changed files with 89 additions and 0 deletions

View file

@ -0,0 +1,57 @@
import json
from pathlib import Path
from metagpt.const import DABENCH_PATH
from examples.di.requirements_prompt import DABENCH
class DABench:
def __init__(self, questions_file=Path(DABENCH_PATH) / 'da-dev-questions.jsonl', answers_file=Path(DABENCH_PATH) / 'da-dev-labels.jsonl', template = ''):
# Read questions from a JSONL file
with open(questions_file, 'r') as file:
self.questions = {int(json.loads(line)['id']): json.loads(line) for line in file}
# Read answers from a JSONL file
with open(answers_file, 'r') as file:
self.answers = {int(json.loads(line)['id']): json.loads(line) for line in file}
self.template = template if template else DABENCH
def get_question(self, question_id):
"""Retrieve the question by its id."""
return self.questions.get(question_id, "Question not found.")
def get_prompt(self, question_id):
"""Retrieve the question by its id."""
temp = self.get_question(question_id)
return self.template.format(question=temp['question'], constraints=temp['constraints'], format=temp['format'], file_name= str(DABENCH_PATH) + '/da-dev-tables/' + temp['file_name'], level=temp['level'],)
def get_answer(self, answer_id):
"""Retrieve the answer list by its id."""
return self.answers.get(answer_id, "Answer not found.")
def eval(self, id, prediction):
"""Evaluate the prediction against the true label."""
true_label = self.get_answer(id)['common_answers']
# Parse the prediction string into a dictionary of metric-value pairs
pred_dict = {}
for pred in prediction.split(','):
parts = pred.strip().split('[')
metric = parts[0].strip().replace('@', '')
value = float(parts[1].rstrip(']'))
pred_dict[metric] = value
# Sort the true labels to match the order of predictions
sorted_true_label = sorted(true_label, key=lambda x: x[0])
# Compare each prediction with the corresponding true label
correct = True
for metric, true_value in sorted_true_label:
if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6:
correct = False
break
return correct
if __name__ == "__main__":
DA = DABench()
id = 6
prediction = "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
is_correct = DA.eval(id, prediction)
print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")

View file

@ -0,0 +1,11 @@
# InfiAgent-DABench
This example is used to solve the InfiAgent-DABench using Data Interpreter (DI), and obtains 94.93% accuracy using gpt-4o.
## Dataset-install
```
git clone https://github.com/InfiAgent/InfiAgent.git
```
## How to run
```
python run_InfiAgent-DABench.py --id x
```

View file

@ -0,0 +1,17 @@
import json
import fire
from metagpt.roles.di.data_interpreter import DataInterpreter
from DABench import DABench
async def main(id=0):
DA = DABench()
requirement = DA.get_prompt(id)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]['result']
is_correct = DA.eval(id, prediction)
print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,3 +1,6 @@
#InfiAgent-DABench requirements
DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
# ML-Benchmark requirements
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"