Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
kit 2024-10-10 13:46:10 +08:00
parent 4a508957b0
commit 8b79c6c3a1
6 changed files with 142 additions and 26 deletions

View file

@ -70,10 +70,16 @@ class DABench:
true_label = self.get_answer(id)["common_answers"]
# Parse the prediction string into a dictionary of metric-value pairs
pred_dict = {}
for pred in prediction.split(","):
for pred in prediction.split("@"):
if pred == "":
continue
parts = pred.strip().split("[")
metric = parts[0].strip().replace("@", "")
value = float(parts[1].rstrip("]"))
metric = parts[0].strip().replace(",", "")
value = parts[1].replace(",", "").replace("]", "")
try:
value = float(value)
except:
value = value
pred_dict[metric] = value
# Sort the true labels to match the order of predictions
@ -82,7 +88,16 @@ class DABench:
# Compare each prediction with the corresponding true label
correct = True
for metric, true_value in sorted_true_label:
if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6:
try:
true_value = float(true_value)
except:
true_value = true_value
if isinstance(true_value, (int, float)) and (
metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
):
correct = False
break
if isinstance(true_value, str) and (metric not in pred_dict or str(pred_dict[metric]) != str(true_value)):
correct = False
break
@ -97,10 +112,16 @@ class DABench:
pred_dict = {}
# Parse the prediction string into a dictionary of metric-value pairs
for pred in prediction.split(","):
for pred in prediction.split("@"):
if pred == "":
continue
parts = pred.strip().split("[")
metric = parts[0].strip().replace("@", "")
value = float(parts[1].rstrip("]"))
metric = parts[0].strip().replace(",", "")
value = parts[1].replace(",", "").replace("]", "")
try:
value = float(value)
except:
value = value
pred_dict[metric] = value
# Initialize the correctness dictionary with False values
@ -108,11 +129,16 @@ class DABench:
# Check each metric's prediction against the true label
for metric, true_value in true_label:
try:
true_value = float(true_value)
except:
true_value = true_value
if metric in pred_dict:
# Consider the prediction correct if it's within a small tolerance
if abs(pred_dict[metric] - float(true_value)) < 1e-6:
if isinstance(true_value, (int, float)) and abs(pred_dict[metric] - true_value) < 1e-6:
correctness[metric] = True
if isinstance(true_value, str) and str(pred_dict[metric]) == str(true_value):
correctness[metric] = True
return correctness
results = []
@ -134,10 +160,15 @@ class DABench:
if __name__ == "__main__":
DA = DABench()
id = [0, 5, 6]
prediction = [
"@mean_fare[34.89]",
"@correlation_coefficient[0.21]",
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
]
print(DA.eval_all(id, prediction))
# id = [0, 5, 6]
# prediction = [
# "@mean_fare[34.89]",
# "@correlation_coefficient[0.21]",
# "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
# ]
id = 6
prediction = (
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
)
print(DA.eval(id, prediction))
print(DA.get_answer(id))

View file

@ -8,5 +8,6 @@ ## Dataset-install
## How to run
```
python run_InfiAgent-DABench_sigle.py --id x # run a task
python run_InfiAgent-DABench_all.py # run all tasks
python run_InfiAgent-DABench_all.py # Run all tasks serially
python run_InfiAgent-DABench.py # Run all tasks in parallel
```

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import argparse
import asyncio
import json
from DABench import DABench
from metagpt.roles.di.data_interpreter import DataInterpreter
def init_agent(*args, **kwargs):
    """Placeholder hook: accept any arguments, do nothing, and return ``None``."""
    return None
async def get_prediction(agent_class, requirement):
    """Run one agent on ``requirement`` and extract its final prediction.

    Args:
        agent_class: Either an agent class or an already-built agent instance.
            A class is instantiated here, so every call works on its own
            agent. An instance is used as-is, which means it is shared with
            every other call that was handed the same object.
        requirement: The task prompt forwarded to ``agent.run``.

    Returns:
        The ``"result"`` field of the last entry of the JSON plan found between
        the ``"Current Plan"`` and ``"## Current Task"`` markers in the
        stringified agent output, or ``None`` when running the agent or
        parsing its output raised an exception.
    """
    try:
        # Build a fresh agent when handed a class; previously the argument was
        # only aliased, so passing a class failed on the unbound ``run`` call.
        agent = agent_class() if isinstance(agent_class, type) else agent_class
        result = await agent.run(requirement)
        # The plan is embedded as JSON between two fixed markers in the output.
        plan_text = str(result).split("Current Plan")[1].split("## Current Task")[0]
        prediction_json = json.loads(plan_text)
        return prediction_json[-1]["result"]
    except Exception as e:
        # Deliberately best-effort: one failing task must not abort the batch,
        # so report the problem and signal failure with None.
        print(f"Error processing requirement: {requirement}. Error: {e}")
        return None
async def evaluate_all(agent_class):
    """Evaluate all tasks in DABench concurrently with the given baseline agent.

    Args:
        agent_class: The agent (class or instance) forwarded unchanged to
            ``get_prediction`` for every task.

    Prints the value returned by ``DABench.eval_all`` for all task ids and
    their predictions.
    """
    DA = DABench()
    id_list = []
    tasks = []
    for key in DA.answers:
        tasks.append(get_prediction(agent_class, DA.get_prompt(key)))
        id_list.append(key)
    # Run all tasks concurrently; gather returns results in submission order,
    # so raw_predictions[i] belongs to id_list[i].
    raw_predictions = await asyncio.gather(*tasks)
    # Keep ids and predictions aligned: dropping failed (None) predictions
    # would shift every later prediction onto the wrong task id. A failed task
    # is scored as an empty prediction instead.
    # NOTE(review): assumes eval_all accepts "" the same way eval does - confirm.
    predictions = ["" if pred is None else pred for pred in raw_predictions]
    print(DA.eval_all(id_list, predictions))
def main():
    """Parse ``--agent_name`` and run the parallel DABench evaluation.

    Only ``"DataInterpreter"`` is recognised; any other name prints a
    message and returns without evaluating anything.
    """
    parser = argparse.ArgumentParser(description="Run evaluation with different baselines.")
    parser.add_argument(
        "--agent_name",
        type=str,
        default="DataInterpreter",
        help="Specify the baseline agent class to use for evaluation.",
    )
    agent_name = parser.parse_args().agent_name

    # Guard clause: bail out early on names we cannot map to an agent.
    # To support more baselines, extend this name-to-agent matching.
    if agent_name != "DataInterpreter":
        print(f"Agent {agent_name} not recognized.")
        return

    # A single DataInterpreter instance is built here and handed to the
    # evaluation, which forwards it to every task.
    agent = DataInterpreter()
    asyncio.run(evaluate_all(agent))
# Start the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

View file

@ -1,6 +1,7 @@
import json
import fire
import pandas as pd
from DABench import DABench
from metagpt.roles.di.data_interpreter import DataInterpreter
@ -9,15 +10,28 @@ from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
"""Evaluate all"""
DA = DABench()
id_list, predictions = [], []
id_list, predictions, labels, is_true = [], [], [], []
for key, value in DA.answers.items():
requirement = DA.get_prompt(key)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
id_list.append(key)
predictions.append(prediction)
try:
requirement = DA.get_prompt(key)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
id_list.append(key)
is_true.append(str(DA.eval(key, prediction)))
predictions.append(str(prediction))
labels.append(str(DA.get_answer(key)))
except:
id_list.append(key)
is_true.append(str(DA.eval(key, "")))
predictions.append(str(""))
labels.append(str(DA.get_answer(key)))
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
# 将DataFrame写入Excel文件
df.to_excel("output.xlsx", index=False)
print(DA.eval_all(id_list, predictions))
# 将列表转换为pandas DataFrame
if __name__ == "__main__":

View file

@ -6,7 +6,7 @@ from DABench import DABench
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(id=5):
async def main(id=0):
DA = DABench()
requirement = DA.get_prompt(id)
di = DataInterpreter()

View file

@ -1,5 +1,5 @@
# InfiAgent-DABench requirements
DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
DABENCH = "You are required to solve the problem within a CSV file named {file_name}. **Problem**: {question}, **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}."
# ML-Benchmark requirements
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"