diff --git a/examples/di/InfiAgent-DABench/DABench.py b/examples/di/InfiAgent-DABench/DABench.py index 0e3b7ccb4..c6018f7c0 100644 --- a/examples/di/InfiAgent-DABench/DABench.py +++ b/examples/di/InfiAgent-DABench/DABench.py @@ -70,10 +70,16 @@ class DABench: true_label = self.get_answer(id)["common_answers"] # Parse the prediction string into a dictionary of metric-value pairs pred_dict = {} - for pred in prediction.split(","): + for pred in prediction.split("@"): + if pred == "": + continue parts = pred.strip().split("[") - metric = parts[0].strip().replace("@", "") - value = float(parts[1].rstrip("]")) + metric = parts[0].strip().replace(",", "") + value = parts[1].replace(",", "").replace("]", "") + try: + value = float(value) + except: + value = value pred_dict[metric] = value # Sort the true labels to match the order of predictions @@ -82,7 +88,16 @@ class DABench: # Compare each prediction with the corresponding true label correct = True for metric, true_value in sorted_true_label: - if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6: + try: + true_value = float(true_value) + except: + true_value = true_value + if isinstance(true_value, (int, float)) and ( + metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6 + ): + correct = False + break + if isinstance(true_value, str) and (metric not in pred_dict or str(pred_dict[metric]) != str(true_value)): correct = False break @@ -97,10 +112,16 @@ class DABench: pred_dict = {} # Parse the prediction string into a dictionary of metric-value pairs - for pred in prediction.split(","): + for pred in prediction.split("@"): + if pred == "": + continue parts = pred.strip().split("[") - metric = parts[0].strip().replace("@", "") - value = float(parts[1].rstrip("]")) + metric = parts[0].strip().replace(",", "") + value = parts[1].replace(",", "").replace("]", "") + try: + value = float(value) + except: + value = value pred_dict[metric] = value # Initialize the correctness 
dictionary with False values @@ -108,11 +129,16 @@ class DABench: # Check each metric's prediction against the true label for metric, true_value in true_label: + try: + true_value = float(true_value) + except: + true_value = true_value if metric in pred_dict: # Consider the prediction correct if it's within a small tolerance - if abs(pred_dict[metric] - float(true_value)) < 1e-6: + if isinstance(true_value, (int, float)) and abs(pred_dict[metric] - true_value) < 1e-6: + correctness[metric] = True + if isinstance(true_value, str) and str(pred_dict[metric]) == str(true_value): correctness[metric] = True - return correctness results = [] @@ -134,10 +160,15 @@ class DABench: if __name__ == "__main__": DA = DABench() - id = [0, 5, 6] - prediction = [ - "@mean_fare[34.89]", - "@correlation_coefficient[0.21]", - "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]", - ] - print(DA.eval_all(id, prediction)) + # id = [0, 5, 6] + # prediction = [ + # "@mean_fare[34.89]", + # "@correlation_coefficient[0.21]", + # "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]", + # ] + id = 6 + prediction = ( + "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]" + ) + print(DA.eval(id, prediction)) + print(DA.get_answer(id)) diff --git a/examples/di/InfiAgent-DABench/README.md b/examples/di/InfiAgent-DABench/README.md index 6e60811ea..603bed82b 100644 --- a/examples/di/InfiAgent-DABench/README.md +++ b/examples/di/InfiAgent-DABench/README.md @@ -8,5 +8,6 @@ ## Dataset-install ## How to run ``` python run_InfiAgent-DABench_sigle.py --id x # run a task -python run_InfiAgent-DABench_all.py # run all tasks +python run_InfiAgent-DABench_all.py # Run all tasks serially +python run_InfiAgent-DABench.py # Run all tasks in parallel ``` \ No newline at end of file diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py 
# -*- coding: utf-8 -*-
"""Run all InfiAgent-DABench tasks concurrently and print the evaluation summary.

Usage:
    python run_InfiAgent-DABench.py [--agent_name DataInterpreter] [--max_concurrency N]
"""
import argparse
import asyncio
import importlib
import inspect
import json

# Maps the ``--agent_name`` CLI value to the "module:ClassName" that implements it.
# The class is imported only once it has been selected, so ``--help`` and the
# "not recognized" path do not need the agent package to be importable.
AGENT_REGISTRY = {
    "DataInterpreter": "metagpt.roles.di.data_interpreter:DataInterpreter",
    # Add more agents as needed, e.g.
    # "OtherAgent": "package.module:OtherAgent",
}

# Prediction recorded for a task whose run failed.  The serial runner
# (run_InfiAgent-DABench_all.py) also scores a failed task with an empty string.
FAILED_PREDICTION = ""


def init_agent(*args, **kwargs):
    """Placeholder hook kept for backward compatibility; accepts anything, returns None."""
    return


def _extract_prediction(result):
    """Return the ``result`` field of the last plan entry found in the agent output.

    The JSON plan is expected between the markers "Current Plan" and
    "## Current Task" in ``str(result)``.

    Raises:
        IndexError: the "Current Plan" marker is missing, or the plan is empty.
        json.JSONDecodeError: the text between the markers is not valid JSON.
        KeyError: the last plan entry has no "result" key.
    """
    plan_text = str(result).split("Current Plan")[1].split("## Current Task")[0]
    plan = json.loads(plan_text)
    return plan[-1]["result"]


async def get_prediction(agent_class, requirement):
    """Run one agent on ``requirement`` and return its final prediction.

    Args:
        agent_class: The agent class; a new instance is created for this call.
            An already constructed agent instance is still accepted and used
            as is (previous behaviour).
        requirement: The task prompt handed to ``agent.run``.

    Returns:
        The prediction extracted from the agent output, or ``None`` when the
        run or the parsing failed (the error is printed, not raised).
    """
    try:
        # Instantiate the agent inside this function to avoid memory conflicts
        # between tasks that are running concurrently.
        agent = agent_class() if inspect.isclass(agent_class) else agent_class
        result = await agent.run(requirement)
        return _extract_prediction(result)
    except Exception as e:
        # Best-effort by design: one failing task must not abort the whole benchmark.
        print(f"Error processing requirement: {requirement}. Error: {e}")
        return None


async def evaluate_all(agent_class, max_concurrency=None):
    """Evaluate all tasks in DABench using the specified baseline agent.

    Args:
        agent_class: Agent class (or instance) forwarded to ``get_prediction``.
        max_concurrency: Optional positive int limiting how many tasks run at
            the same time.  ``None`` (default) runs all of them at once.
    """
    # Imported here so that importing this module does not require the benchmark files.
    from DABench import DABench

    DA = DABench()
    id_list = list(DA.answers)
    requirements = [DA.get_prompt(task_id) for task_id in id_list]
    limiter = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def run_task(requirement):
        if limiter is None:
            return await get_prediction(agent_class, requirement)
        async with limiter:
            return await get_prediction(agent_class, requirement)

    # gather() returns results in the order of its arguments, so predictions[i]
    # belongs to id_list[i].
    predictions = await asyncio.gather(*(run_task(req) for req in requirements))
    # Keep one prediction per id.  Dropping the failed ones would shift every
    # later prediction onto the wrong task id during evaluation.
    predictions = [FAILED_PREDICTION if pred is None else pred for pred in predictions]
    print(DA.eval_all(id_list, predictions))


def main():
    """Parse the command line, resolve the agent class and run the evaluation."""
    parser = argparse.ArgumentParser(description="Run evaluation with different baselines.")
    parser.add_argument(
        "--agent_name",
        type=str,
        default="DataInterpreter",
        help="Specify the baseline agent class to use for evaluation.",
    )
    parser.add_argument(
        "--max_concurrency",
        type=int,
        default=None,
        help="Maximum number of tasks to run at the same time (default: no limit).",
    )
    args = parser.parse_args()

    target = AGENT_REGISTRY.get(args.agent_name)
    if target is None:
        print(f"Agent {args.agent_name} not recognized.")
        return
    module_name, _, class_name = target.partition(":")
    # Pass the class, not an instance, so that every task gets its own agent.
    agent_class = getattr(importlib.import_module(module_name), class_name)

    asyncio.run(evaluate_all(agent_class, args.max_concurrency))


if __name__ == "__main__":
    main()
print(DA.eval_all(id_list, predictions)) + # 将列表转换为pandas DataFrame if __name__ == "__main__": diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py index 3dc4ad0e9..6db562da8 100644 --- a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py +++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py @@ -6,7 +6,7 @@ from DABench import DABench from metagpt.roles.di.data_interpreter import DataInterpreter -async def main(id=5): +async def main(id=0): DA = DABench() requirement = DA.get_prompt(id) di = DataInterpreter() diff --git a/examples/di/requirements_prompt.py b/examples/di/requirements_prompt.py index 1172e1fe5..08eedcadc 100644 --- a/examples/di/requirements_prompt.py +++ b/examples/di/requirements_prompt.py @@ -1,5 +1,5 @@ # InfiAgent-DABench requirements -DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}." +DABENCH = "You are required to solve the problem within a CSV file named {file_name}. **Problem**: {question}, **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}." # ML-Benchmark requirements IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"