Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
kit 2024-10-25 00:47:39 +08:00
parent 0c16c321fe
commit d9ad8fe005
4 changed files with 143 additions and 70 deletions

View file

@ -1,6 +1,10 @@
import asyncio
import json
import re
from pathlib import Path
import nest_asyncio
from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH
@ -65,43 +69,38 @@ class DABench:
"""Retrieve the answer list by its id."""
return self.answers.get(answer_id, "Answer not found.")
def eval(self, id, prediction):
def eval(self, id: str, result: str) -> bool:
"""Evaluate the prediction against the true label."""
# prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
prediction = result
true_label = self.get_answer(id)["common_answers"]
# Parse the prediction string into a dictionary of metric-value pairs
pred_dict = {}
for pred in prediction.split("@"):
if pred == "":
continue
parts = pred.strip().split("[")
metric = parts[0].strip().replace(",", "")
value = parts[1].replace(",", "").replace("]", "")
nest_asyncio.apply()
cleaned_prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
if cleaned_prediction: # Ensure it's not empty
try:
value = float(value)
pred_dict = parse_prediction(cleaned_prediction)
if compare_predictions(pred_dict, true_label):
return (prediction, True)
except:
value = value
pred_dict[metric] = value
print("format errer, using gpt to refomat")
# Sort the true labels to match the order of predictions
sorted_true_label = sorted(true_label, key=lambda x: x[0])
# Compare each prediction with the corresponding true label
correct = True
for metric, true_value in sorted_true_label:
# If the cleaned prediction is not valid, try the async reformat
try:
prediction = asyncio.run(
reformat(self.get_question(id)["question"], self.get_question(id)["format"], result)
)
try:
true_value = float(true_value)
prediction = prediction.split("Answer{{")[1].split("}}")[0].strip()
except:
true_value = true_value
if isinstance(true_value, (int, float)) and (
metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
):
correct = False
break
if isinstance(true_value, str) and (metric not in pred_dict or str(pred_dict[metric]) != str(true_value)):
correct = False
break
pass
pred_dict = parse_prediction(prediction)
if compare_predictions(pred_dict, true_label):
return (prediction, True)
except Exception as e:
print(f"Error during async reformat: {e}")
# Skip this step if there's an error
return correct
return (prediction, False)
def eval_all(self, id_list, predictions):
"""Evaluate all predictions and calculate accuracy rates."""
@ -109,35 +108,27 @@ class DABench:
def sigle_eval(id, prediction):
"""Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
true_label = self.get_answer(id)["common_answers"]
pred_dict = {}
# Parse the prediction string into a dictionary of metric-value pairs
for pred in prediction.split("@"):
if pred == "":
continue
parts = pred.strip().split("[")
metric = parts[0].strip().replace(",", "")
value = parts[1].replace(",", "").replace("]", "")
try:
value = float(value)
except:
value = value
pred_dict[metric] = value
prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
pred_dict = parse_prediction(prediction)
# Initialize the correctness dictionary with False values
correctness = {metric: False for metric, _ in true_label}
# Check each metric's prediction against the true label
for metric, true_value in true_label:
try:
true_value = float(true_value)
except:
true_value = true_value
true_value = true_value.replace(",", "")
if metric in pred_dict:
# Consider the prediction correct if it's within a small tolerance
if isinstance(true_value, (int, float)) and abs(pred_dict[metric] - true_value) < 1e-6:
if (
isinstance(true_value, (int, float))
and isinstance(pred_dict[metric], (int, float))
and abs(pred_dict[metric] - true_value) < 1e-6
):
correctness[metric] = True
if isinstance(true_value, str) and str(pred_dict[metric]) == str(true_value):
if isinstance(true_value, str) and (
metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
):
correctness[metric] = True
return correctness
@ -158,6 +149,91 @@ class DABench:
}
async def ask_and_print(question, system_prompt):
    """Send *question* to the LLM with a single system message and return its reply.

    Args:
        question: The payload forwarded unchanged to ``LLM.aask``.
        system_prompt: The system message; wrapped in a one-element list.

    Returns:
        Whatever ``LLM.aask`` returns for this request.
    """
    # Imported lazily, inside the function, exactly as the original did.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
async def reformat(question, format, response):
    """Ask the LLM to restate *response* in the ``@name[value]`` answer format.

    A three-turn conversation is built: the original question (user), the raw
    response (assistant), and a reformatting instruction (user) that embeds
    two few-shot examples plus the question's own format requirements.

    Args:
        question: The original task question.
        format: The format requirement text of the question. (The name shadows
            the builtin, but is kept so keyword callers keep working.)
        response: The raw answer text that should be reformatted.

    Returns:
        The reply produced by ``ask_and_print`` for the conversation.
    """
    system_prompt = "You are a helpful assistant."
    # Raw string: the text intentionally contains the literal sequences
    # "\Format" and "\Answer". In a normal string literal "\F" and "\A" are
    # invalid escape sequences, which newer Python versions warn about.
    # The doubled braces are literal too: this text is substituted as a
    # *value* into the template below, so str.format never collapses them.
    demons = r"""\Format{{
@shapiro_wilk_statistic[test_statistic]
@shapiro_wilk_p_value[p_value]
where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
}}
\Answer{{
@shapiro_wilk_statistic[0.56]
@shapiro_wilk_p_value[0.0002]
}}
\Format{{
@total_votes_outliers_num[outlier_num]
where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
}}
\Answer{{
@total_votes_outliers[10]
}}
"""
    reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}.
Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
The format requirements of this question is:
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": response},
        {"role": "user", "content": reformat_template.format(demons=demons, format=format)},
    ]
    return await ask_and_print(messages, system_prompt)
def parse_prediction(prediction: str) -> dict:
    """Parse an ``@metric[value]`` prediction string into a metric -> value dict.

    The string is split on ``@``; each segment is split on square brackets.
    The first non-blank fragment is the metric name and the last one is the
    value. Commas are removed everywhere and colons are removed from the
    value. Values that look numeric are converted to ``float``; everything
    else stays a string.

    Args:
        prediction: Text such as ``"@a[0.56], @b[0.0002]"``.

    Returns:
        A dict mapping each metric name to a ``float`` or ``str`` value.
        Segments that contain nothing but whitespace or commas are skipped.
    """
    pred_dict = {}
    for segment in prediction.split("@"):
        fragments = (piece.replace(",", "").strip() for piece in re.split(r"[\[\]]", segment.strip()))
        # Drop blank fragments so that separators such as ", " between two
        # answers (or stray whitespace after the closing bracket) are never
        # mistaken for the value.
        parts = [piece for piece in fragments if piece]
        if not parts:
            # Nothing usable in this segment (e.g. leading whitespace before
            # the first "@"); indexing parts[0] here would raise IndexError.
            continue
        metric = parts[0]
        value = parts[-1].replace(":", "").strip()
        try:
            value = float(value)
        except ValueError:
            pass  # Keep the value as a string when it is not numeric.
        pred_dict[metric] = value
    return pred_dict
def _numbers_match(predicted, expected: float, tolerance: float = 1e-6) -> bool:
    """Return True when *predicted* is a number within *tolerance* of *expected*."""
    import math  # Local import so the block does not depend on file-level imports.

    if not isinstance(predicted, (int, float)):
        # A prediction that could not be parsed as a number never matches a
        # numeric label (and must not reach the subtraction below).
        return False
    if predicted == expected:
        return True  # Also covers equal infinities, where the difference is NaN.
    if math.isnan(predicted) or math.isnan(expected):
        return math.isnan(predicted) and math.isnan(expected)
    return abs(predicted - expected) <= tolerance


def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """Check that every labelled metric is present in *pred_dict* and matches.

    Labels that parse as numbers (optionally after removing thousands
    separators) are compared numerically with an absolute tolerance of 1e-6.
    All other labels are compared as case-insensitive strings after commas
    are removed from the label.

    Args:
        pred_dict: Mapping of metric name to predicted value (``float`` or ``str``).
        true_label: Iterable of ``(metric, true_value)`` pairs.

    Returns:
        True if every metric in *true_label* has a matching prediction,
        otherwise False. An empty *true_label* yields True.
    """
    for metric, true_value in true_label:
        if metric not in pred_dict:
            return False
        predicted = pred_dict[metric]

        try:
            expected = float(true_value)
        except ValueError:
            expected = true_value.replace(",", "")
            try:
                # Labels such as "1,234" are numeric once the commas are gone.
                expected = float(expected)
            except ValueError:
                pass

        if isinstance(expected, float):
            if not _numbers_match(predicted, expected):
                return False
        elif str(predicted).lower() != str(expected).lower():
            return False
    return True
if __name__ == "__main__":
DA = DABench()
# id = [0, 5, 6]
@ -166,9 +242,6 @@ if __name__ == "__main__":
# "@correlation_coefficient[0.21]",
# "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
# ]
id = 6
prediction = (
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
)
id = 760
prediction = "@most_missing_station_name[AGE00135039]@most_missing_station_count[0]"
print(DA.eval(id, prediction))
print(DA.get_answer(id))

View file

@ -1,37 +1,37 @@
import json
import fire
import pandas as pd
from DABench import DABench
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils.recovery_util import save_history
async def main():
"""Evaluate all"""
DA = DABench()
id_list, predictions, labels, is_true = [], [], [], []
for key, value in DA.answers.items():
id_list.append(key)
labels.append(str(DA.get_answer(key)))
try:
requirement = DA.get_prompt(key)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
id_list.append(key)
is_true.append(str(DA.eval(key, prediction)))
predictions.append(str(prediction))
labels.append(str(DA.get_answer(key)))
logger.info(result)
save_history(role=di)
temp_prediction, temp_istrue = DA.eval(key, str(result))
is_true.append(str(temp_istrue))
predictions.append(str(temp_prediction))
except:
id_list.append(key)
is_true.append(str(DA.eval(key, "")))
predictions.append(str(""))
labels.append(str(DA.get_answer(key)))
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
# Write the DataFrame to an Excel file
df.to_excel("output.xlsx", index=False)
print(DA.eval_all(id_list, predictions))
# Convert the lists into a pandas DataFrame
if __name__ == "__main__":

View file

@ -1,9 +1,9 @@
import json
import fire
from DABench import DABench
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils.recovery_util import save_history
async def main(id=0):
@ -11,8 +11,9 @@ async def main(id=0):
requirement = DA.get_prompt(id)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
is_correct = DA.eval(id, prediction)
logger.info(result)
save_history(role=di)
_, is_correct = DA.eval(id, str(result))
print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")

View file

@ -1,6 +1,5 @@
# InfiAgent-DABench requirements
DABENCH = "You are required to solve the problem within a CSV file named {file_name}. **Problem**: {question}, **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}."
DABENCH = "You are required to {question} from a CSV file named {file_name}. **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}. This task is categorized as {level}."
# ML-Benchmark requirements
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"