mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
1
Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
parent
0c16c321fe
commit
d9ad8fe005
4 changed files with 143 additions and 70 deletions
|
|
@ -1,6 +1,10 @@
|
|||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import nest_asyncio
|
||||
|
||||
from examples.di.requirements_prompt import DABENCH
|
||||
from metagpt.const import DABENCH_PATH
|
||||
|
||||
|
|
@ -65,43 +69,38 @@ class DABench:
|
|||
"""Retrieve the answer list by its id."""
|
||||
return self.answers.get(answer_id, "Answer not found.")
|
||||
|
||||
def eval(self, id, prediction):
|
||||
def eval(self, id: str, result: str) -> bool:
|
||||
"""Evaluate the prediction against the true label."""
|
||||
# prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
prediction = result
|
||||
true_label = self.get_answer(id)["common_answers"]
|
||||
# Parse the prediction string into a dictionary of metric-value pairs
|
||||
pred_dict = {}
|
||||
for pred in prediction.split("@"):
|
||||
if pred == "":
|
||||
continue
|
||||
parts = pred.strip().split("[")
|
||||
metric = parts[0].strip().replace(",", "")
|
||||
value = parts[1].replace(",", "").replace("]", "")
|
||||
nest_asyncio.apply()
|
||||
cleaned_prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
|
||||
if cleaned_prediction: # Ensure it's not empty
|
||||
try:
|
||||
value = float(value)
|
||||
pred_dict = parse_prediction(cleaned_prediction)
|
||||
if compare_predictions(pred_dict, true_label):
|
||||
return (prediction, True)
|
||||
except:
|
||||
value = value
|
||||
pred_dict[metric] = value
|
||||
print("format errer, using gpt to refomat")
|
||||
|
||||
# Sort the true labels to match the order of predictions
|
||||
sorted_true_label = sorted(true_label, key=lambda x: x[0])
|
||||
|
||||
# Compare each prediction with the corresponding true label
|
||||
correct = True
|
||||
for metric, true_value in sorted_true_label:
|
||||
# If the cleaned prediction is not valid, try the async reformat
|
||||
try:
|
||||
prediction = asyncio.run(
|
||||
reformat(self.get_question(id)["question"], self.get_question(id)["format"], result)
|
||||
)
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
prediction = prediction.split("Answer{{")[1].split("}}")[0].strip()
|
||||
except:
|
||||
true_value = true_value
|
||||
if isinstance(true_value, (int, float)) and (
|
||||
metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
|
||||
):
|
||||
correct = False
|
||||
break
|
||||
if isinstance(true_value, str) and (metric not in pred_dict or str(pred_dict[metric]) != str(true_value)):
|
||||
correct = False
|
||||
break
|
||||
pass
|
||||
pred_dict = parse_prediction(prediction)
|
||||
if compare_predictions(pred_dict, true_label):
|
||||
return (prediction, True)
|
||||
except Exception as e:
|
||||
print(f"Error during async reformat: {e}")
|
||||
# Skip this step if there's an error
|
||||
|
||||
return correct
|
||||
return (prediction, False)
|
||||
|
||||
def eval_all(self, id_list, predictions):
|
||||
"""Evaluate all predictions and calculate accuracy rates."""
|
||||
|
|
@ -109,35 +108,27 @@ class DABench:
|
|||
def sigle_eval(id, prediction):
|
||||
"""Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
|
||||
true_label = self.get_answer(id)["common_answers"]
|
||||
pred_dict = {}
|
||||
|
||||
# Parse the prediction string into a dictionary of metric-value pairs
|
||||
for pred in prediction.split("@"):
|
||||
if pred == "":
|
||||
continue
|
||||
parts = pred.strip().split("[")
|
||||
metric = parts[0].strip().replace(",", "")
|
||||
value = parts[1].replace(",", "").replace("]", "")
|
||||
try:
|
||||
value = float(value)
|
||||
except:
|
||||
value = value
|
||||
pred_dict[metric] = value
|
||||
|
||||
prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
|
||||
pred_dict = parse_prediction(prediction)
|
||||
# Initialize the correctness dictionary with False values
|
||||
correctness = {metric: False for metric, _ in true_label}
|
||||
|
||||
# Check each metric's prediction against the true label
|
||||
for metric, true_value in true_label:
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
except:
|
||||
true_value = true_value
|
||||
true_value = true_value.replace(",", "")
|
||||
if metric in pred_dict:
|
||||
# Consider the prediction correct if it's within a small tolerance
|
||||
if isinstance(true_value, (int, float)) and abs(pred_dict[metric] - true_value) < 1e-6:
|
||||
if (
|
||||
isinstance(true_value, (int, float))
|
||||
and isinstance(pred_dict[metric], (int, float))
|
||||
and abs(pred_dict[metric] - true_value) < 1e-6
|
||||
):
|
||||
correctness[metric] = True
|
||||
if isinstance(true_value, str) and str(pred_dict[metric]) == str(true_value):
|
||||
if isinstance(true_value, str) and (
|
||||
metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
|
||||
):
|
||||
correctness[metric] = True
|
||||
return correctness
|
||||
|
||||
|
|
@ -158,6 +149,91 @@ class DABench:
|
|||
}
|
||||
|
||||
|
||||
async def ask_and_print(question, system_prompt):
|
||||
from metagpt.llm import LLM
|
||||
|
||||
llm = LLM()
|
||||
rsp = await llm.aask(question, system_msgs=[system_prompt])
|
||||
return rsp
|
||||
|
||||
|
||||
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
|
||||
async def reformat(question, format, response):
|
||||
system_prompt = "You are a helpful assistant."
|
||||
demons = """\Format{{
|
||||
@shapiro_wilk_statistic[test_statistic]
|
||||
@shapiro_wilk_p_value[p_value]
|
||||
where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
|
||||
where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
|
||||
}}
|
||||
\Answer{{
|
||||
@shapiro_wilk_statistic[0.56]
|
||||
@shapiro_wilk_p_value[0.0002]
|
||||
}}
|
||||
|
||||
\Format{{
|
||||
@total_votes_outliers_num[outlier_num]
|
||||
where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
|
||||
}}
|
||||
\Answer{{
|
||||
@total_votes_outliers[10]
|
||||
}}
|
||||
"""
|
||||
reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}.
|
||||
Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
|
||||
The format requirements of this question is:
|
||||
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
|
||||
# res = """[['monthly_avg_windspeed', "{'month_1': 7.17, 'month_2': 6.53, 'month_3': 5.9, 'month_4': 6.69, 'month_5': 5.43, 'month_6': 5.82, 'month_7': 5.13, 'month_8': 5.72, 'month_9': 5.69, 'month_10': 6.57, 'month_11': 5.79, 'month_12': 5.52}"]]}"""
|
||||
messages = [{"role": "user", "content": question}]
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
messages.append({"role": "user", "content": reformat_template.format(demons=demons, format=format)})
|
||||
rsp = await ask_and_print(messages, system_prompt)
|
||||
return rsp
|
||||
|
||||
|
||||
def parse_prediction(prediction: str) -> dict:
|
||||
"""Parse the prediction string into a dictionary of metric-value pairs."""
|
||||
pred_dict = {}
|
||||
for pred in prediction.split("@"):
|
||||
if pred == "":
|
||||
continue
|
||||
temp = re.split(r"[\[\]]", pred.strip())
|
||||
temp = [s.replace(",", "") for s in temp]
|
||||
parts = [s for s in temp if s]
|
||||
metric = parts[0].strip().replace(",", "")
|
||||
value = parts[-1].replace(",", "").replace(":", "")
|
||||
|
||||
try:
|
||||
value = float(value)
|
||||
except ValueError:
|
||||
pass # Keep value as string if conversion fails
|
||||
|
||||
pred_dict[metric] = value
|
||||
return pred_dict
|
||||
|
||||
|
||||
def compare_predictions(pred_dict: dict, true_label: list) -> bool:
|
||||
"""Compare each prediction with the corresponding true label."""
|
||||
sorted_true_label = sorted(true_label, key=lambda x: x[0])
|
||||
|
||||
for metric, true_value in sorted_true_label:
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
except ValueError:
|
||||
true_value = true_value.replace(",", "")
|
||||
|
||||
if isinstance(true_value, (int, float)) and (
|
||||
metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
|
||||
):
|
||||
return False
|
||||
if isinstance(true_value, str) and (
|
||||
metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
DA = DABench()
|
||||
# id = [0, 5, 6]
|
||||
|
|
@ -166,9 +242,6 @@ if __name__ == "__main__":
|
|||
# "@correlation_coefficient[0.21]",
|
||||
# "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
|
||||
# ]
|
||||
id = 6
|
||||
prediction = (
|
||||
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
|
||||
)
|
||||
id = 760
|
||||
prediction = "@most_missing_station_name[AGE00135039]@most_missing_station_count[0]"
|
||||
print(DA.eval(id, prediction))
|
||||
print(DA.get_answer(id))
|
||||
|
|
|
|||
|
|
@ -1,37 +1,37 @@
|
|||
import json
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
from DABench import DABench
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
|
||||
|
||||
async def main():
|
||||
"""Evaluate all"""
|
||||
DA = DABench()
|
||||
id_list, predictions, labels, is_true = [], [], [], []
|
||||
|
||||
for key, value in DA.answers.items():
|
||||
id_list.append(key)
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
try:
|
||||
requirement = DA.get_prompt(key)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
id_list.append(key)
|
||||
is_true.append(str(DA.eval(key, prediction)))
|
||||
predictions.append(str(prediction))
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
logger.info(result)
|
||||
save_history(role=di)
|
||||
temp_prediction, temp_istrue = DA.eval(key, str(result))
|
||||
is_true.append(str(temp_istrue))
|
||||
predictions.append(str(temp_prediction))
|
||||
|
||||
except:
|
||||
id_list.append(key)
|
||||
is_true.append(str(DA.eval(key, "")))
|
||||
predictions.append(str(""))
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
|
||||
|
||||
# 将DataFrame写入Excel文件
|
||||
df.to_excel("output.xlsx", index=False)
|
||||
print(DA.eval_all(id_list, predictions))
|
||||
# 将列表转换为pandas DataFrame
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import json
|
||||
|
||||
import fire
|
||||
from DABench import DABench
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
|
||||
|
||||
async def main(id=0):
|
||||
|
|
@ -11,8 +11,9 @@ async def main(id=0):
|
|||
requirement = DA.get_prompt(id)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
is_correct = DA.eval(id, prediction)
|
||||
logger.info(result)
|
||||
save_history(role=di)
|
||||
_, is_correct = DA.eval(id, str(result))
|
||||
print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# InfiAgent-DABench requirements
|
||||
DABENCH = "You are required to solve the problem within a CSV file named {file_name}. **Problem**: {question}, **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}."
|
||||
|
||||
DABENCH = "You are required to {question} from a CSV file named {file_name}. **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}. This task is categorized as {level}."
|
||||
# ML-Benchmark requirements
|
||||
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue