diff --git a/examples/di/InfiAgent-DABench/DABench.py b/examples/di/InfiAgent-DABench/DABench.py
new file mode 100644
index 000000000..4994acb78
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/DABench.py
@@ -0,0 +1,57 @@
+import json
+from pathlib import Path
+from metagpt.const import DABENCH_PATH
+from examples.di.requirements_prompt import DABENCH
+class DABench:
+    def __init__(self, questions_file=Path(DABENCH_PATH) / 'da-dev-questions.jsonl', answers_file=Path(DABENCH_PATH) / 'da-dev-labels.jsonl', template=''):
+        # Read questions from a JSONL file
+        with open(questions_file, 'r') as file:
+            self.questions = {int(json.loads(line)['id']): json.loads(line) for line in file}
+
+        # Read answers from a JSONL file
+        with open(answers_file, 'r') as file:
+            self.answers = {int(json.loads(line)['id']): json.loads(line) for line in file}
+
+        self.template = template if template else DABENCH
+    def get_question(self, question_id):
+        """Retrieve the question by its id."""
+        return self.questions.get(question_id, "Question not found.")
+
+    def get_prompt(self, question_id):
+        """Build the full requirement prompt for the question with this id."""
+        temp = self.get_question(question_id)
+        return self.template.format(question=temp['question'], constraints=temp['constraints'], format=temp['format'], file_name=str(DABENCH_PATH) + '/da-dev-tables/' + temp['file_name'], level=temp['level'])
+
+    def get_answer(self, answer_id):
+        """Retrieve the answer list by its id."""
+        return self.answers.get(answer_id, "Answer not found.")
+
+    def eval(self, id, prediction):
+        """Evaluate the prediction against the true label."""
+        true_label = self.get_answer(id)['common_answers']
+        # Parse the prediction string, e.g. "@metric[value]", into a dict of metric-value pairs
+        pred_dict = {}
+        for pred in prediction.split(','):
+            parts = pred.strip().split('[')
+            metric = parts[0].strip().replace('@', '')
+            value = float(parts[1].rstrip(']'))
+            pred_dict[metric] = value
+
+        # Sort the true labels for a deterministic comparison order
+        sorted_true_label = sorted(true_label, key=lambda x: x[0])
+
+        # Every true metric must be predicted and must match within 1e-6
+        correct = True
+        for metric, true_value in sorted_true_label:
+            if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6:
+                correct = False
+                break
+
+        return correct
+
+if __name__ == "__main__":
+    DA = DABench()
+    id = 6
+    prediction = "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
+    is_correct = DA.eval(id, prediction)
+    print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
\ No newline at end of file
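For reference, a minimal standalone sketch of the matching contract `eval` implements: predictions arrive as comma-separated `@metric[value]` pairs, and each label stores `[metric, value]` pairs under `common_answers` (the schema is inferred from how `eval` unpacks it; the sample values below are illustrative, not from the dataset):

```python
# Standalone sketch of the matching contract used by DABench.eval above.
# The "@metric[value]" format and the [metric, value] label pairs are
# inferred from eval(); sample values here are illustrative only.
def parse_prediction(prediction: str) -> dict:
    pred_dict = {}
    for pred in prediction.split(","):
        metric, _, value = pred.strip().partition("[")
        pred_dict[metric.strip().lstrip("@")] = float(value.rstrip("]"))
    return pred_dict

common_answers = [["mean_fare_adult", "35.17"], ["mean_fare_child", "31.09"]]
preds = parse_prediction("@mean_fare_child[31.09], @mean_fare_adult[35.17]")
assert all(abs(preds[m] - float(v)) <= 1e-6 for m, v in common_answers)
```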
diff --git a/examples/di/InfiAgent-DABench/README.md b/examples/di/InfiAgent-DABench/README.md
new file mode 100644
index 000000000..23eb6504f
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/README.md
@@ -0,0 +1,13 @@
+# InfiAgent-DABench
+This example solves the InfiAgent-DABench benchmark with the Data Interpreter (DI) and reaches 94.93% accuracy using gpt-4o.
+
+## Dataset install
+Clone the dataset repository into this directory (`examples/di/InfiAgent-DABench`), since `DABENCH_PATH` expects the data under `InfiAgent/examples/DA-Agent/data`:
+```
+git clone https://github.com/InfiAgent/InfiAgent.git
+```
+## How to run
+```
+python run_InfiAgent-DABench.py --id x
+```
+Replace `x` with the id of the question to evaluate.
\ No newline at end of file
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
new file mode 100644
index 000000000..d68845b91
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
@@ -0,0 +1,18 @@
+import json
+import fire
+
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from DABench import DABench
+
+async def main(id=0):
+    DA = DABench()
+    requirement = DA.get_prompt(id)
+    di = DataInterpreter()
+    result = await di.run(requirement)
+    # Extract the last task's result from the plan JSON embedded in DI's output
+    prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]['result']
+    is_correct = DA.eval(id, prediction)
+    print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/examples/di/requirements_prompt.py b/examples/di/requirements_prompt.py
index 04a0414b1..4b9950611 100644
--- a/examples/di/requirements_prompt.py
+++ b/examples/di/requirements_prompt.py
@@ -1,3 +1,6 @@
+# InfiAgent-DABench requirements
+DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
+
 # ML-Benchmark requirements
 IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
 WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"
diff --git a/metagpt/const.py b/metagpt/const.py
index f33b46b68..c13cb1dfa 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -43,6 +43,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
 EXAMPLE_PATH = METAGPT_ROOT / "examples"
 EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
 DATA_PATH = METAGPT_ROOT / "data"
+DABENCH_PATH = EXAMPLE_PATH / "di/InfiAgent-DABench/InfiAgent/examples/DA-Agent/data"
 EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
 TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
 RESEARCH_PATH = DATA_PATH / "research"
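A hedged sketch of how these pieces could be combined for batch evaluation over several question ids; the `run_one` and `main` helpers below are hypothetical, not part of this PR, and the result-extraction line simply mirrors `run_InfiAgent-DABench.py`:

```python
# Hypothetical batch driver (not part of this PR): evaluate several ids in turn.
import asyncio
import json

from metagpt.roles.di.data_interpreter import DataInterpreter
from DABench import DABench


async def run_one(bench: DABench, qid: int) -> bool:
    di = DataInterpreter()  # fresh interpreter per question, as in the run script
    result = await di.run(bench.get_prompt(qid))
    # Same extraction as run_InfiAgent-DABench.py: last task's result from the plan JSON
    tasks = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])
    return bench.eval(qid, tasks[-1]["result"])


async def main(ids=(0, 1, 2)):
    bench = DABench()
    results = [await run_one(bench, qid) for qid in ids]
    print(f"{sum(results)}/{len(results)} correct")


if __name__ == "__main__":
    asyncio.run(main())
```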