Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
kit 2024-10-27 13:06:28 +08:00
parent d9ad8fe005
commit ea699caeee
6 changed files with 465 additions and 220 deletions

View file

@ -2,27 +2,71 @@ import asyncio
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import nest_asyncio
from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH
from metagpt.utils.exceptions import handle_exception
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
def evaluate_accuracy_by_question(results: List[Dict[str, Any]]) -> float:
    """
    Calculate the fraction of results in which every sub-question is correct.

    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py

    A result counts as correct only when it carries a non-empty 'correctness'
    mapping whose values are all truthy. An empty mapping is treated as
    incorrect: `all()` over no values is True, which would otherwise score an
    evaluation that produced no per-metric verdicts as a perfect answer.

    Args:
        results: A sequence of result dicts, each optionally holding a
            'correctness' mapping of metric name -> bool.

    Returns:
        float: Correct results divided by all results, rounded to four decimal
        places. Returns 0 if there are no results.
    """
    total = len(results)
    if total == 0:
        return 0

    fully_correct = 0
    for result in results:
        correctness = result.get("correctness")
        # Require at least one verdict so that an empty dict is not a "pass".
        if correctness and all(correctness.values()):
            fully_correct += 1
    return round(fully_correct / total, 4)
def evaluate_accuracy_by_sub_question(results: List[Dict[str, Any]]) -> float:
    """
    Calculate the fraction of correct sub-questions across all results.

    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py

    Results without a 'correctness' field are ignored. The counts are gathered
    in a single pass, so a one-shot iterable (e.g. a generator) is also handled.

    Args:
        results: An iterable of result dicts, each optionally holding a
            'correctness' mapping of metric name -> bool.

    Returns:
        float: Correct sub-questions divided by all sub-questions, rounded to
        four decimal places. Returns 0 if there are no sub-questions.
    """
    correct = 0
    total = 0
    for result in results:
        if "correctness" not in result:
            continue
        verdicts = result["correctness"]
        correct += sum(verdicts.values())
        total += len(verdicts)
    return round(correct / total, 4) if total > 0 else 0
def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
def evaluate_accuracy_proportional_by_sub_question_adjusted(results: dict) -> float:
"""
Adjust the score based on the number of sub-questions in each result.
This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
This function calculates a score for each result by considering the number of sub-questions
it contains. Each sub-question is assigned a score of 1 divided by the number of sub-questions.
The total score for each result is computed as the sum of all correct sub-questions multiplied
by the score per sub-question. Finally, it returns the average score across all results.
Args:
results (dict): A collection of results where each result may contain a 'correctness' field.
Returns:
float: The average score across all results, rounded to four decimal places.
Returns 0 if there are no results.
"""
total_score = 0
for result in results:
if "correctness" in result:
@ -33,132 +77,23 @@ def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
return round(total_score / len(results), 4) if results else 0
class DABench:
def __init__(
self,
questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
template="",
):
# Read questions from a JSONL file
with open(questions_file, "r") as file:
self.questions = {int(json.loads(line)["id"]): json.loads(line) for line in file}
async def reformat(question: str, format: str, response: str) -> str:
"""
Asynchronously reformats a given response based on specified formatting requirements.
This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
This function constructs a prompt for the LLM (Large Language Model) to reformat
the provided response according to the specified format. It includes a system prompt
to guide the LLM's behavior and a template that outlines the expected output structure.
# Read answers from a JSONL file
with open(answers_file, "r") as file:
self.answers = {int(json.loads(line)["id"]): json.loads(line) for line in file}
Args:
question (str): The original question posed by the user.
format (str): The specific formatting requirements that the response must adhere to.
response (str): The initial response from the LLM that needs to be reformatted.
self.template = template if template else DABENCH
def get_question(self, question_id):
"""Retrieve the question by its id."""
return self.questions.get(question_id, "Question not found.")
def get_prompt(self, question_id):
"""Retrieve the question by its id."""
temp = self.get_question(question_id)
return self.template.format(
question=temp["question"],
constraints=temp["constraints"],
format=temp["format"],
file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
level=temp["level"],
)
def get_answer(self, answer_id):
"""Retrieve the answer list by its id."""
return self.answers.get(answer_id, "Answer not found.")
def eval(self, id: str, result: str) -> bool:
"""Evaluate the prediction against the true label."""
# prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
prediction = result
true_label = self.get_answer(id)["common_answers"]
nest_asyncio.apply()
cleaned_prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
if cleaned_prediction: # Ensure it's not empty
try:
pred_dict = parse_prediction(cleaned_prediction)
if compare_predictions(pred_dict, true_label):
return (prediction, True)
except:
print("format errer, using gpt to refomat")
# If the cleaned prediction is not valid, try the async reformat
try:
prediction = asyncio.run(
reformat(self.get_question(id)["question"], self.get_question(id)["format"], result)
)
try:
prediction = prediction.split("Answer{{")[1].split("}}")[0].strip()
except:
pass
pred_dict = parse_prediction(prediction)
if compare_predictions(pred_dict, true_label):
return (prediction, True)
except Exception as e:
print(f"Error during async reformat: {e}")
# Skip this step if there's an error
return (prediction, False)
def eval_all(self, id_list, predictions):
"""Evaluate all predictions and calculate accuracy rates."""
def sigle_eval(id, prediction):
"""Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
true_label = self.get_answer(id)["common_answers"]
prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
pred_dict = parse_prediction(prediction)
# Initialize the correctness dictionary with False values
correctness = {metric: False for metric, _ in true_label}
# Check each metric's prediction against the true label
for metric, true_value in true_label:
try:
true_value = float(true_value)
except:
true_value = true_value.replace(",", "")
if metric in pred_dict:
# Consider the prediction correct if it's within a small tolerance
if (
isinstance(true_value, (int, float))
and isinstance(pred_dict[metric], (int, float))
and abs(pred_dict[metric] - true_value) < 1e-6
):
correctness[metric] = True
if isinstance(true_value, str) and (
metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
):
correctness[metric] = True
return correctness
results = []
for id, prediction in zip(id_list, predictions):
correct = sigle_eval(id, prediction)
results.append({"id": id, "correctness": correct})
# Calculate the three accuracy rates
accuracy_by_question = evaluate_accuracy_by_question(results)
accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)
return {
"accuracy_by_question": accuracy_by_question,
"accuracy_by_sub_question": accuracy_by_sub_question,
"proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
}
async def ask_and_print(question, system_prompt):
    """
    Send `question` to the LLM with `system_prompt` as the only system message.

    Args:
        question: The message payload forwarded unchanged to `LLM.aask`.
        system_prompt: Text passed as the single entry of `system_msgs`.

    Returns:
        The value returned by `LLM.aask`.
    """
    # Imported inside the function, as in the original, so the LLM module is
    # only loaded when a request is actually made.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
async def reformat(question, format, response):
Returns:
str: The reformatted response generated by the LLM based on the provided question
and formatting requirements.
"""
system_prompt = "You are a helpful assistant."
demons = """\Format{{
@shapiro_wilk_statistic[test_statistic]
@ -183,65 +118,369 @@ async def reformat(question, format, response):
Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
The format requirements of this question is:
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
# res = """[['monthly_avg_windspeed', "{'month_1': 7.17, 'month_2': 6.53, 'month_3': 5.9, 'month_4': 6.69, 'month_5': 5.43, 'month_6': 5.82, 'month_7': 5.13, 'month_8': 5.72, 'month_9': 5.69, 'month_10': 6.57, 'month_11': 5.79, 'month_12': 5.52}"]]}"""
messages = [{"role": "user", "content": question}]
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": reformat_template.format(demons=demons, format=format)})
rsp = await ask_and_print(messages, system_prompt)
messages = [
{"role": "user", "content": question},
{"role": "assistant", "content": response},
{"role": "user", "content": reformat_template.format(demons=demons, format=format)},
]
rsp = await ask(messages, system_prompt)
return rsp
def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file into a list of dictionaries.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped instead of being handed to `json.loads`, which would raise on them.

    Args:
        file_path (Union[Path, str]): The path to the JSONL file to be loaded.

    Returns:
        List[Dict[str, Any]]: One parsed object per non-blank line, in file order.
    """
    records = []
    with Path(file_path).open("r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """
    Compare each prediction against the corresponding true label.

    Every (metric, value) pair in `true_label` must be present in `pred_dict`.
    A label that converts to float is matched numerically with an absolute
    tolerance of 1e-6; any other label has its commas removed and is matched
    as a case-insensitive string.

    A non-numeric prediction for a numeric label is a mismatch rather than a
    TypeError, and a NaN prediction is a mismatch (NaN compares False against
    every threshold, so a plain `> 1e-6` test would let it through).

    Args:
        pred_dict (dict): Predicted metric name -> value (float or str).
        true_label (list): Pairs of (metric name, true value).

    Returns:
        bool: True if all predictions match the true labels, False otherwise.
    """
    # Sorted by metric name, as in the original implementation.
    for metric, true_value in sorted(true_label, key=lambda x: x[0]):
        if metric not in pred_dict:
            return False
        predicted = pred_dict[metric]

        try:
            true_value = float(true_value)
        except ValueError:
            true_value = true_value.replace(",", "")

        if isinstance(true_value, float):
            if not isinstance(predicted, (int, float)):
                return False
            # Written as "not <=" so that a NaN difference is rejected.
            if not abs(predicted - true_value) <= 1e-6:
                return False
        elif str(predicted).lower() != str(true_value).lower():
            return False

    return True
async def ask(question: str, system_prompt: str) -> str:
    """
    Send a question to the LLM (Large Language Model) and return its response.

    A fresh `LLM` instance is created on every call and `question` is passed
    to `LLM.aask` with `system_prompt` as the only system message.

    NOTE(review): `reformat` in this file calls this function with a list of
    chat-message dicts, so the `str` annotation on `question` looks too
    narrow — confirm against `LLM.aask`.

    Args:
        question (str): The message payload forwarded to the LLM.
        system_prompt (str): Context or instructions for the LLM.

    Returns:
        str: The response produced by `LLM.aask`.
    """
    # Imported inside the function, as in the original.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])
def parse_prediction(prediction: str) -> dict:
    """
    Parse a prediction string into a dictionary of metric-value pairs.

    The input is a sequence of "@metric[value]" entries. It is split on "@";
    each segment is split on square brackets, commas are removed, and empty
    pieces are discarded. The first remaining piece is the metric name and the
    last one is the value (with ":" removed). Values that convert to float are
    stored as floats, everything else stays a string.

    Segments that contain nothing usable (empty, whitespace only, or just
    brackets/commas) are skipped instead of raising IndexError.

    Args:
        prediction (str): A string representation of metrics and their values.

    Returns:
        dict: Metric name -> value, either a float or a string.
    """
    pred_dict = {}
    for segment in prediction.split("@"):
        pieces = [piece.replace(",", "") for piece in re.split(r"[\[\]]", segment.strip())]
        parts = [piece for piece in pieces if piece]
        if not parts:
            continue  # Nothing between two "@" markers, or only whitespace/punctuation

        metric = parts[0].strip()
        value = parts[-1].replace(":", "")
        try:
            value = float(value)
        except ValueError:
            pass  # Not numeric: keep the cleaned string

        pred_dict[metric] = value
    return pred_dict
class DABench:
    """Holds the DABench questions and labels and scores predictions against them."""

    def __init__(
        self,
        questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl",
        answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl",
        template: str = "",
    ):
        """
        Load questions and answers from JSONL files and select the prompt template.

        Args:
            questions_file (Path): The path to the JSONL file containing questions.
            answers_file (Path): The path to the JSONL file containing answers.
            template (str): Prompt template; the DABENCH default is used when empty.
        """
        # Both mappings are keyed by the integer form of each record's "id".
        self.questions = {int(line["id"]): line for line in load_jsonl(questions_file)}
        self.answers = {int(line["id"]): line for line in load_jsonl(answers_file)}
        self.template = template if template else DABENCH

    def get_question(self, question_id: int) -> Union[dict, str]:
        """
        Retrieve the question record for the given ID.

        Args:
            question_id (int): The question id (records are keyed by int).

        Returns:
            The question record if found, otherwise the string "Question not found.".
        """
        return self.questions.get(question_id, "Question not found.")

    def generate_formatted_prompt(self, question_id: int) -> str:
        """
        Fill the prompt template with the data of the given question.

        Args:
            question_id (int): The question id.

        Returns:
            str: The template formatted with the question, constraints, format,
            table file path, and level of the question.
        """
        temp = self.get_question(question_id)
        return self.template.format(
            question=temp["question"],
            constraints=temp["constraints"],
            format=temp["format"],
            file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
            level=temp["level"],
        )

    def get_answer(self, answer_id: int) -> Union[dict, str]:
        """
        Retrieve the label record for the given ID.

        Args:
            answer_id (int): The question id.

        Returns:
            The label record if found, otherwise the string "Answer not found.".
        """
        return self.answers.get(answer_id, "Answer not found.")

    @staticmethod
    def _extract_result(result: Any) -> str:
        """Return the last task's "result" from an agent run dump, or the raw text if it has no such section."""
        text = str(result)
        try:
            plan = json.loads(text.split("Current Plan")[1].split("## Current Task")[0])
            return str(plan[-1]["result"]).strip()
        except (IndexError, KeyError, TypeError, ValueError):
            # No "Current Plan" section or no parseable JSON in it: the caller
            # passed a bare prediction string (or an empty one).
            return text.strip()

    @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False))
    def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]:
        """
        Parse the cleaned prediction and compare it with the true label.

        Args:
            cleaned_prediction (str): The cleaned prediction string.
            true_label (Any): The true label to compare against.

        Returns:
            Tuple[str, bool]: The cleaned prediction and whether it matches the label.
        """
        if cleaned_prediction:
            pred_dict = parse_prediction(cleaned_prediction)
            if pred_dict is not None and compare_predictions(pred_dict, true_label):
                return cleaned_prediction, True
        return cleaned_prediction, False

    # The fallback is a string so that the value always matches the declared
    # return type; eval() treats an empty string as "reformat failed".
    @handle_exception(exception_msg="Error during async reformat", default_return="")
    def async_reformat_prediction(self, id: int, result: str) -> str:
        """
        Ask the LLM to reformat the prediction and extract the answer part.

        Args:
            id (int): The question id.
            result (str): The original prediction text.

        Returns:
            str: The text between "Answer{{" and "}}" when present, otherwise the
            whole reformatted response.
        """
        question_data = self.get_question(id)
        prediction = asyncio.run(reformat(question_data["question"], question_data["format"], result))
        if "Answer{{" in prediction:
            return prediction.split("Answer{{")[1].split("}}")[0].strip()
        return prediction

    def eval(self, id: int, result: str) -> Tuple[str, bool]:
        """
        Evaluate one prediction against the true label.

        The prediction is first taken from the agent's plan dump (or used as is
        when `result` is already a bare prediction). If it does not match the
        label directly, it is reformatted through the LLM and checked again.

        Args:
            id (int): The question id.
            result (str): The agent output or a bare prediction string.

        Returns:
            Tuple[str, bool]: The final prediction and whether it matches the label.
        """
        true_label = self.get_answer(id)["common_answers"]
        nest_asyncio.apply()  # Allows asyncio.run() below when a loop is already running
        result = self._extract_result(result)
        cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "")

        parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
        if parsed_result[1]:
            return parsed_result

        prediction = self.async_reformat_prediction(id, result)
        if not prediction:
            # Reformatting failed or produced nothing; report the original text.
            return result, False

        # Reuse the guarded parser so a malformed LLM answer cannot raise here.
        _, matched = self.parse_cleaned_prediction(prediction, true_label)
        return prediction, matched

    @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
    def single_eval(self, id: int, prediction: str) -> dict:
        """
        Evaluate one prediction metric by metric (used by eval_all).

        Args:
            id (int): The question id.
            prediction (str): The prediction string to evaluate.

        Returns:
            dict: Metric name -> bool. A prediction that cannot be parsed yields
            False for every metric.
        """
        true_label = self.get_answer(id)["common_answers"]
        correctness = {metric: False for metric, _ in true_label}

        try:
            cleaned = prediction.replace("{", "").replace("}", "").replace("'", "")
            pred_dict = parse_prediction(cleaned)
        except (AttributeError, IndexError):
            return correctness  # Missing or malformed prediction: every metric is wrong

        for metric, true_value in true_label:
            try:
                true_value = float(true_value)
            except ValueError:
                true_value = true_value.replace(",", "")

            if metric not in pred_dict:
                continue
            predicted = pred_dict[metric]

            if isinstance(true_value, float):
                # Numeric label: correct when within a small absolute tolerance.
                correctness[metric] = isinstance(predicted, (int, float)) and abs(predicted - true_value) < 1e-6
            else:
                # String label: correct on a case-insensitive match.
                correctness[metric] = str(predicted).lower() == str(true_value).lower()
        return correctness

    def eval_all(self, id_list: list, predictions: list) -> dict:
        """
        Evaluate all predictions and calculate accuracy rates.

        Args:
            id_list (list): Question ids.
            predictions (list): Prediction strings; `predictions[i]` must belong
                to `id_list[i]` (the two lists are paired with zip).

        Returns:
            dict: The three accuracy rates (by question, by sub-question, and
            proportional by sub-question).
        """
        results = [
            {"id": id, "correctness": self.single_eval(id, prediction)}
            for id, prediction in zip(id_list, predictions)
        ]
        return {
            "accuracy_by_question": evaluate_accuracy_by_question(results),
            "accuracy_by_sub_question": evaluate_accuracy_by_sub_question(results),
            "proportional_accuracy_by_sub_question": evaluate_accuracy_proportional_by_sub_question_adjusted(results),
        }
if __name__ == "__main__":
    # Smoke test: score one prediction, then a small batch.
    DA = DABench()

    question_id = 0
    prediction = "@mean_fare[34.65]"
    print(DA.eval(question_id, prediction))

    ids = [0, 5, 6]
    predictions = [
        "@mean_fare[34.89]",
        "@correlation_coefficient[0.21]",
        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
    ]
    print(DA.eval_all(ids, predictions))

View file

@ -1,13 +1,15 @@
# InfiAgent-DABench
This example uses the Data Interpreter (DI) to solve InfiAgent-DABench, and achieves 94.93% accuracy with gpt-4o.
## Dataset-install
## Dataset
```
cd /examples/di/InfiAgent-DABench
git clone https://github.com/InfiAgent/InfiAgent.git
mv InfiAgent/examples/DA-Agent/data ./
```
## How to run
```
python run_InfiAgent-DABench_sigle.py --id x # run a task
python run_InfiAgent-DABench_sigle.py --id x # run a task, x represents the id of the question you want to test
python run_InfiAgent-DABench_all.py # Run all tasks serially
python run_InfiAgent-DABench.py # Run all tasks in parallel
python run_InfiAgent-DABench.py --k x # Run all tasks in parallel, x represents the number of parallel tasks at a time
```

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
import argparse
import asyncio
import json
@ -8,62 +6,70 @@ from DABench import DABench
from metagpt.roles.di.data_interpreter import DataInterpreter
def init_agent(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None
async def get_prediction(agent, requirement):
    """Run the agent on a requirement and extract the prediction from its output.

    The output of `agent.run` is converted to text, the JSON found between
    "Current Plan" and "## Current Task" is parsed, and the "result" of its
    last entry is returned. Any error is printed and turned into None.

    Args:
        agent: An object exposing an awaitable `run(requirement)`.
        requirement: The prompt handed to the agent.

    Returns:
        The extracted prediction, or None if running or parsing failed.
    """
    try:
        output_text = str(await agent.run(requirement))
        after_plan_marker = output_text.split("Current Plan")[1]
        plan_json = after_plan_marker.split("## Current Task")[0]
        return json.loads(plan_json)[-1]["result"]
    except Exception as e:
        print(f"Error processing requirement: {requirement}. Error: {e}")
        return None
async def evaluate_all(agent, k):
    """Evaluate all tasks in DABench using the specified baseline agent.

    Question ids are processed in groups of size k; the tasks of one group run
    concurrently. Questions whose prediction could not be obtained (None) are
    left out of the evaluation together with their id, so that every remaining
    prediction is still scored against its own question.

    Args:
        agent: The baseline agent used for making predictions.
        k (int): The number of tasks to process concurrently in each group.

    Raises:
        ValueError: If k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    DA = DABench()
    question_ids = list(DA.answers)
    evaluated_ids, predictions = [], []

    for start in range(0, len(question_ids), k):
        group_ids = question_ids[start : start + k]
        # Coroutines are created per group so none is left un-awaited if an
        # earlier group fails.
        group_predictions = await asyncio.gather(
            *(get_prediction(agent, DA.generate_formatted_prompt(key)) for key in group_ids)
        )
        for key, prediction in zip(group_ids, group_predictions):
            if prediction is not None:
                evaluated_ids.append(key)
                predictions.append(prediction)

    print(DA.eval_all(evaluated_ids, predictions))
def main(k=5):
    """Run the full DABench evaluation with a DataInterpreter agent.

    Args:
        k: The number of tasks evaluated concurrently per group.
    """
    asyncio.run(evaluate_all(DataInterpreter(), k))
if __name__ == "__main__":

View file

@ -11,12 +11,11 @@ async def main():
"""Evaluate all"""
DA = DABench()
id_list, predictions, labels, is_true = [], [], [], []
for key, value in DA.answers.items():
id_list.append(key)
labels.append(str(DA.get_answer(key)))
try:
requirement = DA.get_prompt(key)
requirement = DA.generate_formatted_prompt(key)
di = DataInterpreter()
result = await di.run(requirement)
logger.info(result)
@ -24,13 +23,11 @@ async def main():
temp_prediction, temp_istrue = DA.eval(key, str(result))
is_true.append(str(temp_istrue))
predictions.append(str(temp_prediction))
except:
is_true.append(str(DA.eval(key, "")))
predictions.append(str(""))
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
df.to_excel("output.xlsx", index=False)
df.to_excel("DABench_output.xlsx", index=False)
print(DA.eval_all(id_list, predictions))

View file

@ -7,8 +7,9 @@ from metagpt.utils.recovery_util import save_history
async def main(id=0):
"""Evaluate one task"""
DA = DABench()
requirement = DA.get_prompt(id)
requirement = DA.generate_formatted_prompt(id)
di = DataInterpreter()
result = await di.run(requirement)
logger.info(result)