mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
1
Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
parent
4a508957b0
commit
8b79c6c3a1
6 changed files with 142 additions and 26 deletions
|
|
@ -70,10 +70,16 @@ class DABench:
|
|||
true_label = self.get_answer(id)["common_answers"]
|
||||
# Parse the prediction string into a dictionary of metric-value pairs
|
||||
pred_dict = {}
|
||||
for pred in prediction.split(","):
|
||||
for pred in prediction.split("@"):
|
||||
if pred == "":
|
||||
continue
|
||||
parts = pred.strip().split("[")
|
||||
metric = parts[0].strip().replace("@", "")
|
||||
value = float(parts[1].rstrip("]"))
|
||||
metric = parts[0].strip().replace(",", "")
|
||||
value = parts[1].replace(",", "").replace("]", "")
|
||||
try:
|
||||
value = float(value)
|
||||
except:
|
||||
value = value
|
||||
pred_dict[metric] = value
|
||||
|
||||
# Sort the true labels to match the order of predictions
|
||||
|
|
@ -82,7 +88,16 @@ class DABench:
|
|||
# Compare each prediction with the corresponding true label
|
||||
correct = True
|
||||
for metric, true_value in sorted_true_label:
|
||||
if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6:
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
except:
|
||||
true_value = true_value
|
||||
if isinstance(true_value, (int, float)) and (
|
||||
metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
|
||||
):
|
||||
correct = False
|
||||
break
|
||||
if isinstance(true_value, str) and (metric not in pred_dict or str(pred_dict[metric]) != str(true_value)):
|
||||
correct = False
|
||||
break
|
||||
|
||||
|
|
@ -97,10 +112,16 @@ class DABench:
|
|||
pred_dict = {}
|
||||
|
||||
# Parse the prediction string into a dictionary of metric-value pairs
|
||||
for pred in prediction.split(","):
|
||||
for pred in prediction.split("@"):
|
||||
if pred == "":
|
||||
continue
|
||||
parts = pred.strip().split("[")
|
||||
metric = parts[0].strip().replace("@", "")
|
||||
value = float(parts[1].rstrip("]"))
|
||||
metric = parts[0].strip().replace(",", "")
|
||||
value = parts[1].replace(",", "").replace("]", "")
|
||||
try:
|
||||
value = float(value)
|
||||
except:
|
||||
value = value
|
||||
pred_dict[metric] = value
|
||||
|
||||
# Initialize the correctness dictionary with False values
|
||||
|
|
@ -108,11 +129,16 @@ class DABench:
|
|||
|
||||
# Check each metric's prediction against the true label
|
||||
for metric, true_value in true_label:
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
except:
|
||||
true_value = true_value
|
||||
if metric in pred_dict:
|
||||
# Consider the prediction correct if it's within a small tolerance
|
||||
if abs(pred_dict[metric] - float(true_value)) < 1e-6:
|
||||
if isinstance(true_value, (int, float)) and abs(pred_dict[metric] - true_value) < 1e-6:
|
||||
correctness[metric] = True
|
||||
if isinstance(true_value, str) and str(pred_dict[metric]) == str(true_value):
|
||||
correctness[metric] = True
|
||||
|
||||
return correctness
|
||||
|
||||
results = []
|
||||
|
|
@ -134,10 +160,15 @@ class DABench:
|
|||
|
||||
if __name__ == "__main__":
|
||||
DA = DABench()
|
||||
id = [0, 5, 6]
|
||||
prediction = [
|
||||
"@mean_fare[34.89]",
|
||||
"@correlation_coefficient[0.21]",
|
||||
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
|
||||
]
|
||||
print(DA.eval_all(id, prediction))
|
||||
# id = [0, 5, 6]
|
||||
# prediction = [
|
||||
# "@mean_fare[34.89]",
|
||||
# "@correlation_coefficient[0.21]",
|
||||
# "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
|
||||
# ]
|
||||
id = 6
|
||||
prediction = (
|
||||
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]"
|
||||
)
|
||||
print(DA.eval(id, prediction))
|
||||
print(DA.get_answer(id))
|
||||
|
|
|
|||
|
|
@ -8,5 +8,6 @@ ## Dataset-install
|
|||
## How to run
|
||||
```
|
||||
python run_InfiAgent-DABench_sigle.py --id x # run a task
|
||||
python run_InfiAgent-DABench_all.py # run all tasks
|
||||
python run_InfiAgent-DABench_all.py # Run all tasks serially
|
||||
python run_InfiAgent-DABench.py # Run all tasks in parallel
|
||||
```
|
||||
70
examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
Normal file
70
examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from DABench import DABench
|
||||
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
def init_agent(*args, **kwargs):
    """Placeholder hook: accept any arguments, do nothing, and return ``None``."""
    return None
|
||||
|
||||
|
||||
async def get_prediction(agent_class, requirement):
    """Run one requirement through an agent and return its final result.

    Args:
        agent_class: Agent class to instantiate for this task (preferred), or
            an already constructed agent object, which is used as-is for
            backward compatibility. The agent must expose an awaitable
            ``run(requirement)``.
        requirement: Prompt text handed to ``agent.run``.

    Returns:
        The ``"result"`` field of the last entry of the JSON plan embedded in
        the agent output, or ``None`` if running the agent or parsing its
        output raised an exception (the error is printed).
    """
    try:
        # Instantiate the agent inside this function so that concurrently
        # running tasks do not share one agent object (avoids memory conflicts).
        agent = agent_class() if isinstance(agent_class, type) else agent_class
        result = await agent.run(requirement)
        # The plan is the JSON text between the "Current Plan" and
        # "## Current Task" markers of the stringified result.
        plan_text = str(result).split("Current Plan")[1].split("## Current Task")[0]
        prediction_json = json.loads(plan_text)
        return prediction_json[-1]["result"]
    except Exception as e:
        print(f"Error processing requirement: {requirement}. Error: {e}")
        return None


async def evaluate_all(agent_class):
    """Evaluate all tasks in DABench concurrently using the specified agent.

    Args:
        agent_class: Agent class (or agent object) forwarded to
            ``get_prediction`` for every task.

    Prints the value returned by ``DABench.eval_all`` for all task ids.
    """
    DA = DABench()
    id_list = []
    tasks = []
    for key in DA.answers:
        id_list.append(key)
        tasks.append(get_prediction(agent_class, DA.get_prompt(key)))
    # Run all tasks concurrently; gather returns results in task order.
    raw_predictions = await asyncio.gather(*tasks)
    # Keep predictions aligned one-to-one with id_list: a failed task (None)
    # becomes an empty prediction instead of being dropped, which would shift
    # every later prediction onto the wrong task id.
    # NOTE(review): an empty prediction is presumably scored as incorrect by
    # DABench.eval_all -- confirm.
    predictions = ["" if pred is None else pred for pred in raw_predictions]
    print(DA.eval_all(id_list, predictions))


def main():
    """Parse ``--agent_name`` and run the full DABench evaluation with that agent."""
    parser = argparse.ArgumentParser(description="Run evaluation with different baselines.")
    parser.add_argument(
        "--agent_name",
        type=str,
        default="DataInterpreter",
        help="Specify the baseline agent class to use for evaluation.",
    )
    args = parser.parse_args()
    # Manually match the agent name to the class. Pass the class itself (not an
    # instance) so that get_prediction can build a fresh agent for every task.
    # To support another baseline, add an ``elif`` branch assigning its class.
    if args.agent_name == "DataInterpreter":
        agent_class = DataInterpreter
    else:
        print(f"Agent {args.agent_name} not recognized.")
        return
    # Run the evaluation with the specified agent class
    asyncio.run(evaluate_all(agent_class))


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
from DABench import DABench
|
||||
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
|
@ -9,15 +10,28 @@ from metagpt.roles.di.data_interpreter import DataInterpreter
|
|||
async def main():
|
||||
"""Evaluate all"""
|
||||
DA = DABench()
|
||||
id_list, predictions = [], []
|
||||
id_list, predictions, labels, is_true = [], [], [], []
|
||||
for key, value in DA.answers.items():
|
||||
requirement = DA.get_prompt(key)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
id_list.append(key)
|
||||
predictions.append(prediction)
|
||||
try:
|
||||
requirement = DA.get_prompt(key)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
id_list.append(key)
|
||||
is_true.append(str(DA.eval(key, prediction)))
|
||||
predictions.append(str(prediction))
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
except:
|
||||
id_list.append(key)
|
||||
is_true.append(str(DA.eval(key, "")))
|
||||
predictions.append(str(""))
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
|
||||
|
||||
# Write the DataFrame to an Excel file
|
||||
df.to_excel("output.xlsx", index=False)
|
||||
print(DA.eval_all(id_list, predictions))
|
||||
# Convert the lists to a pandas DataFrame
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from DABench import DABench
|
|||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
async def main(id=5):
|
||||
async def main(id=0):
|
||||
DA = DABench()
|
||||
requirement = DA.get_prompt(id)
|
||||
di = DataInterpreter()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# InfiAgent-DABench requirements
|
||||
DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."
|
||||
DABENCH = "You are required to solve the problem within a CSV file named {file_name}. **Problem**: {question}, **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}."
|
||||
|
||||
# ML-Benchmark requirements
|
||||
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue