mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-30 14:35:17 +02:00
1
Signed-off-by: kit <101046518@qq.com>
This commit is contained in:
parent
d9ad8fe005
commit
ea699caeee
6 changed files with 465 additions and 220 deletions
|
|
@ -2,27 +2,71 @@ import asyncio
|
|||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import nest_asyncio
|
||||
|
||||
from examples.di.requirements_prompt import DABENCH
|
||||
from metagpt.const import DABENCH_PATH
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
|
||||
def evaluate_accuracy_by_question(results: List[Dict[str, Any]]) -> float:
    """
    Calculate the share of questions whose sub-questions are all answered correctly.

    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py

    A result counts as correct only when it has a non-empty 'correctness' mapping
    and every value in that mapping is truthy. Results without a 'correctness'
    field, or with an empty one, count as incorrect but still contribute to the
    denominator.

    Args:
        results (List[Dict[str, Any]]): One dict per question; each may contain a
            'correctness' dict mapping sub-question name to a bool.

    Returns:
        float: The proportion of fully correct results, rounded to four decimal places.
            Returns 0 if there are no results.
    """
    total = len(results)
    if total == 0:
        return 0

    correct = sum(
        1
        for result in results
        # An empty mapping must not pass: all({}.values()) is vacuously True and
        # would otherwise award full credit to a result with no graded sub-questions.
        if "correctness" in result and result["correctness"] and all(result["correctness"].values())
    )
    return round(correct / total, 4)
|
||||
|
||||
|
||||
def evaluate_accuracy_by_sub_question(results: List[Dict[str, Any]]) -> float:
    """
    Calculate the share of correctly answered sub-questions across all results.

    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py

    Only results that carry a 'correctness' field are considered. The numerator is
    the number of truthy values over all such mappings; the denominator is the
    total number of entries in them.

    Args:
        results (List[Dict[str, Any]]): One dict per question; each may contain a
            'correctness' dict mapping sub-question name to a bool.

    Returns:
        float: The ratio of correct sub-questions, rounded to four decimal places.
            Returns 0 if there are no sub-questions.
    """
    graded = [result["correctness"] for result in results if "correctness" in result]
    total = sum(len(correctness) for correctness in graded)
    if total == 0:
        return 0

    correct = sum(sum(correctness.values()) for correctness in graded)
    return round(correct / total, 4)
|
||||
|
||||
|
||||
def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
|
||||
def evaluate_accuracy_proportional_by_sub_question_adjusted(results: dict) -> float:
|
||||
"""
|
||||
Adjust the score based on the number of sub-questions in each result.
|
||||
This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
|
||||
This function calculates a score for each result by considering the number of sub-questions
|
||||
it contains. Each sub-question is assigned a score of 1 divided by the number of sub-questions.
|
||||
The total score for each result is computed as the sum of all correct sub-questions multiplied
|
||||
by the score per sub-question. Finally, it returns the average score across all results.
|
||||
|
||||
Args:
|
||||
results (dict): A collection of results where each result may contain a 'correctness' field.
|
||||
|
||||
Returns:
|
||||
float: The average score across all results, rounded to four decimal places.
|
||||
Returns 0 if there are no results.
|
||||
"""
|
||||
total_score = 0
|
||||
for result in results:
|
||||
if "correctness" in result:
|
||||
|
|
@ -33,132 +77,23 @@ def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
|
|||
return round(total_score / len(results), 4) if results else 0
|
||||
|
||||
|
||||
class DABench:
|
||||
def __init__(
|
||||
self,
|
||||
questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
|
||||
answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
|
||||
template="",
|
||||
):
|
||||
# Read questions from a JSONL file
|
||||
with open(questions_file, "r") as file:
|
||||
self.questions = {int(json.loads(line)["id"]): json.loads(line) for line in file}
|
||||
async def reformat(question: str, format: str, response: str) -> str:
|
||||
"""
|
||||
Asynchronously reformats a given response based on specified formatting requirements.
|
||||
This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
|
||||
This function constructs a prompt for the LLM (Large Language Model) to reformat
|
||||
the provided response according to the specified format. It includes a system prompt
|
||||
to guide the LLM's behavior and a template that outlines the expected output structure.
|
||||
|
||||
# Read answers from a JSONL file
|
||||
with open(answers_file, "r") as file:
|
||||
self.answers = {int(json.loads(line)["id"]): json.loads(line) for line in file}
|
||||
Args:
|
||||
question (str): The original question posed by the user.
|
||||
format (str): The specific formatting requirements that the response must adhere to.
|
||||
response (str): The initial response from the LLM that needs to be reformatted.
|
||||
|
||||
self.template = template if template else DABENCH
|
||||
|
||||
def get_question(self, question_id):
|
||||
"""Retrieve the question by its id."""
|
||||
return self.questions.get(question_id, "Question not found.")
|
||||
|
||||
def get_prompt(self, question_id):
|
||||
"""Retrieve the question by its id."""
|
||||
temp = self.get_question(question_id)
|
||||
return self.template.format(
|
||||
question=temp["question"],
|
||||
constraints=temp["constraints"],
|
||||
format=temp["format"],
|
||||
file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
|
||||
level=temp["level"],
|
||||
)
|
||||
|
||||
def get_answer(self, answer_id):
|
||||
"""Retrieve the answer list by its id."""
|
||||
return self.answers.get(answer_id, "Answer not found.")
|
||||
|
||||
def eval(self, id: str, result: str) -> bool:
|
||||
"""Evaluate the prediction against the true label."""
|
||||
# prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
|
||||
prediction = result
|
||||
true_label = self.get_answer(id)["common_answers"]
|
||||
nest_asyncio.apply()
|
||||
cleaned_prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
|
||||
if cleaned_prediction: # Ensure it's not empty
|
||||
try:
|
||||
pred_dict = parse_prediction(cleaned_prediction)
|
||||
if compare_predictions(pred_dict, true_label):
|
||||
return (prediction, True)
|
||||
except:
|
||||
print("format errer, using gpt to refomat")
|
||||
|
||||
# If the cleaned prediction is not valid, try the async reformat
|
||||
try:
|
||||
prediction = asyncio.run(
|
||||
reformat(self.get_question(id)["question"], self.get_question(id)["format"], result)
|
||||
)
|
||||
try:
|
||||
prediction = prediction.split("Answer{{")[1].split("}}")[0].strip()
|
||||
except:
|
||||
pass
|
||||
pred_dict = parse_prediction(prediction)
|
||||
if compare_predictions(pred_dict, true_label):
|
||||
return (prediction, True)
|
||||
except Exception as e:
|
||||
print(f"Error during async reformat: {e}")
|
||||
# Skip this step if there's an error
|
||||
|
||||
return (prediction, False)
|
||||
|
||||
def eval_all(self, id_list, predictions):
|
||||
"""Evaluate all predictions and calculate accuracy rates."""
|
||||
|
||||
def sigle_eval(id, prediction):
|
||||
"""Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
|
||||
true_label = self.get_answer(id)["common_answers"]
|
||||
prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
|
||||
pred_dict = parse_prediction(prediction)
|
||||
# Initialize the correctness dictionary with False values
|
||||
correctness = {metric: False for metric, _ in true_label}
|
||||
# Check each metric's prediction against the true label
|
||||
for metric, true_value in true_label:
|
||||
try:
|
||||
true_value = float(true_value)
|
||||
except:
|
||||
true_value = true_value.replace(",", "")
|
||||
if metric in pred_dict:
|
||||
# Consider the prediction correct if it's within a small tolerance
|
||||
if (
|
||||
isinstance(true_value, (int, float))
|
||||
and isinstance(pred_dict[metric], (int, float))
|
||||
and abs(pred_dict[metric] - true_value) < 1e-6
|
||||
):
|
||||
correctness[metric] = True
|
||||
if isinstance(true_value, str) and (
|
||||
metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
|
||||
):
|
||||
correctness[metric] = True
|
||||
return correctness
|
||||
|
||||
results = []
|
||||
for id, prediction in zip(id_list, predictions):
|
||||
correct = sigle_eval(id, prediction)
|
||||
results.append({"id": id, "correctness": correct})
|
||||
|
||||
# Calculate the three accuracy rates
|
||||
accuracy_by_question = evaluate_accuracy_by_question(results)
|
||||
accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
|
||||
proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)
|
||||
|
||||
return {
|
||||
"accuracy_by_question": accuracy_by_question,
|
||||
"accuracy_by_sub_question": accuracy_by_sub_question,
|
||||
"proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
|
||||
}
|
||||
|
||||
|
||||
async def ask_and_print(question, system_prompt):
    """
    Send a question to the LLM together with a system prompt and return its reply.

    Despite the name, nothing is printed here; the response is only returned.

    Args:
        question: The content forwarded to ``LLM.aask``.
        system_prompt: The single system message passed via ``system_msgs``.

    Returns:
        The value produced by awaiting ``LLM.aask``.
    """
    # Imported at call time, as in the original implementation.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])
|
||||
|
||||
|
||||
# This code is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
|
||||
async def reformat(question, format, response):
|
||||
Returns:
|
||||
str: The reformatted response generated by the LLM based on the provided question
|
||||
and formatting requirements.
|
||||
"""
|
||||
system_prompt = "You are a helpful assistant."
|
||||
demons = """\Format{{
|
||||
@shapiro_wilk_statistic[test_statistic]
|
||||
|
|
@ -183,65 +118,369 @@ async def reformat(question, format, response):
|
|||
Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
|
||||
The format requirements of this question is:
|
||||
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
|
||||
# res = """[['monthly_avg_windspeed', "{'month_1': 7.17, 'month_2': 6.53, 'month_3': 5.9, 'month_4': 6.69, 'month_5': 5.43, 'month_6': 5.82, 'month_7': 5.13, 'month_8': 5.72, 'month_9': 5.69, 'month_10': 6.57, 'month_11': 5.79, 'month_12': 5.52}"]]}"""
|
||||
messages = [{"role": "user", "content": question}]
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
messages.append({"role": "user", "content": reformat_template.format(demons=demons, format=format)})
|
||||
rsp = await ask_and_print(messages, system_prompt)
|
||||
messages = [
|
||||
{"role": "user", "content": question},
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": reformat_template.format(demons=demons, format=format)},
|
||||
]
|
||||
rsp = await ask(messages, system_prompt)
|
||||
return rsp
|
||||
|
||||
|
||||
def parse_prediction(prediction: str) -> dict:
    """
    Parse the prediction string into a dictionary of metric-value pairs.

    The input is split on "@"; each segment is split on square brackets, commas
    are dropped, and the first non-empty piece becomes the metric name while the
    last non-empty piece (with ":" removed) becomes the value.

    Args:
        prediction (str): Text such as "@metric_a[1.5], @metric_b[text]".

    Returns:
        dict: Metric name mapped to a float when the value parses as one,
            otherwise to the cleaned string.
    """
    pred_dict = {}
    for segment in prediction.split("@"):
        pieces = [piece.replace(",", "") for piece in re.split(r"[\[\]]", segment.strip())]
        parts = [piece for piece in pieces if piece]
        if not parts:
            # Empty, whitespace-only or punctuation-only segment: nothing to record.
            continue

        metric = parts[0].strip()
        value = parts[-1].replace(":", "")
        try:
            value = float(value)
        except ValueError:
            pass  # Keep value as string if conversion fails

        pred_dict[metric] = value
    return pred_dict


def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file into a list of dictionaries.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped instead of being handed to the JSON parser.

    Args:
        file_path (Union[Path, str]): The path to the JSONL file to be loaded.

    Returns:
        List[Dict[str, Any]]: One parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(Path(file_path), "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]
|
||||
|
||||
|
||||
def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """
    Check that every true label is matched by the corresponding prediction.

    Each true value is converted to a float when possible; otherwise commas are
    removed and it is treated as text. Numeric labels match a numeric prediction
    that is equal or within 1e-6. Text labels match when the string forms are
    equal ignoring case. A missing metric, a non-numeric prediction for a numeric
    label, or a NaN prediction is a mismatch.

    Args:
        pred_dict (dict): Predicted metric names mapped to float or string values.
        true_label (list): Pairs of (metric name, true value).

    Returns:
        bool: True if all predictions match the true labels, False otherwise.
    """
    # Labels are visited in metric-name order, as in the original implementation.
    for metric, true_value in sorted(true_label, key=lambda x: x[0]):
        if metric not in pred_dict:
            return False
        predicted = pred_dict[metric]

        try:
            true_value = float(true_value)  # Numeric label
        except ValueError:
            true_value = true_value.replace(",", "")  # Textual label

        if isinstance(true_value, (int, float)):
            # A string prediction cannot satisfy a numeric label; comparing it
            # arithmetically would raise TypeError.
            if not isinstance(predicted, (int, float)):
                return False
            # Written as "not within tolerance" so that NaN (for which every
            # comparison is False) is rejected; the equality test keeps inf == inf.
            if not (predicted == true_value or abs(predicted - true_value) <= 1e-6):
                return False
        elif str(predicted).lower() != str(true_value).lower():
            return False

    return True
|
||||
|
||||
|
||||
async def ask(question: str, system_prompt: str) -> str:
    """
    Forward a question and a system prompt to the LLM and return its reply.

    A fresh ``LLM`` instance is created on every call and its ``aask`` coroutine
    is awaited with the system prompt wrapped in a one-element list.

    Args:
        question (str): The content to send. Note that ``reformat`` in this file
            passes a list of role/content message dicts here rather than a string.
        system_prompt (str): The system message supplied via ``system_msgs``.

    Returns:
        str: Whatever ``LLM.aask`` returns for the request.
    """
    # Imported at call time, as in the original implementation.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])
|
||||
|
||||
|
||||
def parse_prediction(prediction: str) -> dict:
    """
    Parses a prediction string into a dictionary of metric-value pairs.

    The input is split on "@". Each segment is stripped, split on square
    brackets, and has its commas removed. The first non-empty piece is taken as
    the metric name and the last non-empty piece (with ":" removed) as the value.
    Segments that yield no non-empty piece - empty strings, stray whitespace
    before the first "@", or bare punctuation - are skipped.

    Args:
        prediction (str): A string such as "@metric_a[1.5], @metric_b[text]".

    Returns:
        dict: A dictionary where each key is a metric name and each value is the
            corresponding value, as a float when it parses as one and otherwise
            as the cleaned string. A later duplicate metric overwrites an earlier one.
    """
    pred_dict = {}
    for segment in prediction.split("@"):
        pieces = [piece.replace(",", "") for piece in re.split(r"[\[\]]", segment.strip())]
        parts = [piece for piece in pieces if piece]
        if not parts:
            # Indexing parts[0] here would raise IndexError for e.g. " @a[1]".
            continue

        metric = parts[0].strip()
        value = parts[-1].replace(":", "")
        try:
            value = float(value)
        except ValueError:
            pass  # Non-numeric answers are kept as text

        pred_dict[metric] = value
    return pred_dict
|
||||
|
||||
|
||||
class DABench:
    """
    In-memory view of the DABench questions and labels plus scoring helpers.

    Questions and answers are loaded from JSONL files and indexed by their
    integer ``id`` field. The class can build a prompt for a question, score a
    single raw result (``eval``), and compute aggregate accuracy over many
    predictions (``eval_all``).
    """

    def __init__(
        self,
        questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl",
        answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl",
        template: str = "",
    ):
        """
        Load questions and answers and select the prompt template.

        Args:
            questions_file (Path): The path to the JSONL file containing questions.
            answers_file (Path): The path to the JSONL file containing answers.
            template (str): A format string for prompts; when empty, DABENCH is used.
        """
        # Keys are converted with int(), so lookups must use integer IDs.
        self.questions = {int(record["id"]): record for record in load_jsonl(questions_file)}
        self.answers = {int(record["id"]): record for record in load_jsonl(answers_file)}
        self.template = template if template else DABENCH

    def get_question(self, question_id: int) -> Union[dict, str]:
        """
        Retrieve the question record for the given ID.

        Args:
            question_id (int): The question ID (records are keyed by integer ID).

        Returns:
            Union[dict, str]: The question record, or the string
                "Question not found." when the ID is unknown.
        """
        return self.questions.get(question_id, "Question not found.")

    def generate_formatted_prompt(self, question_id: int) -> str:
        """
        Build the prompt for a question by filling the template.

        The template receives the record's question, constraints, format and
        level fields, plus a file name built from DABENCH_PATH, the
        "/da-dev-tables/" directory and the record's file_name.

        Args:
            question_id (int): The question ID.

        Returns:
            str: The formatted prompt.

        Raises:
            TypeError: If the ID is unknown, because the "not found" string
                returned by ``get_question`` cannot be indexed by field name.
        """
        record = self.get_question(question_id)
        return self.template.format(
            question=record["question"],
            constraints=record["constraints"],
            format=record["format"],
            file_name=str(DABENCH_PATH) + "/da-dev-tables/" + record["file_name"],
            level=record["level"],
        )

    def get_answer(self, answer_id: int) -> Union[dict, str]:
        """
        Retrieve the answer record for the given ID.

        Args:
            answer_id (int): The answer ID (records are keyed by integer ID).

        Returns:
            Union[dict, str]: The answer record, or the string
                "Answer not found." when the ID is unknown.
        """
        return self.answers.get(answer_id, "Answer not found.")

    # NOTE(review): handle_exception presumably returns default_return when the
    # wrapped call raises - confirm against metagpt.utils.exceptions.
    @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False))
    def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]:
        """
        Parse the cleaned prediction and compare it with the true label.

        Args:
            cleaned_prediction (str): The prediction with braces and single quotes removed.
            true_label (Any): The label pairs to compare against.

        Returns:
            Tuple[str, bool]: The cleaned prediction and whether it matches the
                true label. An empty prediction never matches.
        """
        if cleaned_prediction and compare_predictions(parse_prediction(cleaned_prediction), true_label):
            return cleaned_prediction, True
        return cleaned_prediction, False

    @handle_exception(exception_msg="Error during async reformat", default_return=(None, False))
    def async_reformat_prediction(self, id: int, result: str) -> str:
        """
        Ask the LLM to reformat a result and extract the answer portion.

        Args:
            id (int): The question ID.
            result (str): The original prediction result.

        Returns:
            str: The text between "Answer{{" and the next "}}" when that marker is
                present in the reformatted output, otherwise the full reformatted
                output. On failure the decorator's default, which is not a string,
                is presumably returned instead; callers must check for that.
        """
        record = self.get_question(id)
        prediction = asyncio.run(reformat(record["question"], record["format"], result))

        if "Answer{{" in prediction:
            return prediction.split("Answer{{")[1].split("}}")[0].strip()
        return prediction

    @staticmethod
    def _extract_final_result(raw: Any) -> str:
        """Return the last task's result from a "Current Plan" dump, or the raw text itself."""
        text = str(raw)
        try:
            plan = json.loads(text.split("Current Plan")[1].split("## Current Task")[0])
            return plan[-1]["result"].strip()
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            # No parsable plan section: treat the input as the answer text,
            # which is how the module's own __main__ example calls eval().
            return text.strip()

    def eval(self, id: int, result: str) -> Tuple[str, bool]:
        """
        Evaluate the prediction against the true label.

        The result is first scored as-is (after stripping braces and single
        quotes). If that does not match, the LLM is asked to reformat it and the
        reformatted answer is scored instead.

        Args:
            id (int): The question ID.
            result (str): Either the agent output containing a "Current Plan" JSON
                section, or a plain answer string.

        Returns:
            Tuple[str, bool]: The final prediction and whether it matches the true label.
        """
        true_label = self.get_answer(id)["common_answers"]
        nest_asyncio.apply()  # asyncio.run is used below, possibly inside a running loop
        result = self._extract_final_result(result)
        cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "")

        parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
        if parsed_result[1]:
            return parsed_result

        prediction = self.async_reformat_prediction(id, result)
        if not isinstance(prediction, str):
            # Reformatting failed and yielded the decorator's default; fall back to
            # the un-reformatted result rather than trying to parse a non-string.
            return result, False

        try:
            matched = compare_predictions(parse_prediction(prediction), true_label)
        except (AttributeError, IndexError, TypeError, ValueError):
            # A malformed reformatted answer counts as a miss, not a crash.
            matched = False
        return prediction, matched

    @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
    def single_eval(self, id: int, prediction: str) -> dict:
        """
        Evaluate one prediction per metric; used by ``eval_all``.

        Args:
            id (int): The question ID.
            prediction (str): The prediction string to evaluate.

        Returns:
            dict: Metric name mapped to True when the predicted value matches the
                label (numeric within 1e-6, text case-insensitively) and False
                when it is missing or different.
        """
        true_label = self.get_answer(id)["common_answers"]
        cleaned = prediction.replace("{", "").replace("}", "").replace("'", "")
        pred_dict = parse_prediction(cleaned)

        correctness = {metric: False for metric, _ in true_label}
        for metric, true_value in true_label:
            try:
                true_value = float(true_value)
            except ValueError:
                true_value = true_value.replace(",", "")

            if metric not in pred_dict:
                continue  # Missing metrics stay False
            predicted = pred_dict[metric]

            if isinstance(true_value, (int, float)):
                correctness[metric] = (
                    isinstance(predicted, (int, float)) and abs(predicted - true_value) < 1e-6
                )
            else:
                # Text answers are correct when they are EQUAL ignoring case.
                correctness[metric] = str(predicted).lower() == str(true_value).lower()

        return correctness

    def eval_all(self, id_list: list, predictions: list) -> dict:
        """
        Evaluate all predictions and calculate accuracy rates.

        Args:
            id_list (list): Question IDs.
            predictions (list): Prediction strings, paired positionally with ``id_list``.

        Returns:
            dict: The keys "accuracy_by_question", "accuracy_by_sub_question" and
                "proportional_accuracy_by_sub_question", each mapped to its score.
        """
        results = [
            {"id": question_id, "correctness": self.single_eval(question_id, prediction)}
            for question_id, prediction in zip(id_list, predictions)
        ]

        return {
            "accuracy_by_question": evaluate_accuracy_by_question(results),
            "accuracy_by_sub_question": evaluate_accuracy_by_sub_question(results),
            "proportional_accuracy_by_sub_question": evaluate_accuracy_proportional_by_sub_question_adjusted(
                results
            ),
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test; requires the DABench data files under DABENCH_PATH.
    DA = DABench()

    # Score a single prediction for question 0.
    print(DA.eval(0, "@mean_fare[34.65]"))

    # Score a batch of predictions and print the three accuracy rates.
    ids = [0, 5, 6]
    predictions = [
        "@mean_fare[34.89]",
        "@correlation_coefficient[0.21]",
        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
    ]
    print(DA.eval_all(ids, predictions))
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
# InfiAgent-DABench
|
||||
This example uses the Data Interpreter (DI) to solve InfiAgent-DABench and achieves 94.93% accuracy with gpt-4o.
|
||||
|
||||
## Dataset-install
|
||||
## Dataset
|
||||
```
|
||||
cd /examples/di/InfiAgent-DABench
|
||||
git clone https://github.com/InfiAgent/InfiAgent.git
|
||||
mv InfiAgent/examples/DA-Agent/data ./
|
||||
```
|
||||
## How to run
|
||||
```
|
||||
python run_InfiAgent-DABench_sigle.py --id x # run a task
|
||||
python run_InfiAgent-DABench_sigle.py --id x # run a task, x represents the id of the question you want to test
|
||||
python run_InfiAgent-DABench_all.py # Run all tasks serially
|
||||
python run_InfiAgent-DABench.py # Run all tasks in parallel
|
||||
python run_InfiAgent-DABench.py --k x # Run all tasks in parallel, x represents the number of parallel tasks at a time
|
||||
```
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
|
|
@ -8,62 +6,70 @@ from DABench import DABench
|
|||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
def init_agent(*args, **kwargs):
    """Placeholder that accepts any arguments, does nothing and returns None."""
    return None
|
||||
async def get_prediction(agent, requirement):
    """Helper function to obtain a prediction from the given agent.

    The agent is run with the requirement. The text between "Current Plan" and
    "## Current Task" in the stringified result is parsed as JSON, and the
    "result" field of its last element is returned. Any failure along the way is
    reported with print and turned into None.

    Args:
        agent: An object with an awaitable ``run(requirement)`` method.
        requirement: The input requirement for which the prediction is to be made.

    Returns:
        The predicted result if successful, otherwise None.
    """
    try:
        result = await agent.run(requirement)

        # The plan is embedded as JSON between these two markers in the output.
        plan_json = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])
        return plan_json[-1]["result"]
    except Exception as e:
        # Best-effort by design: one failing task must not abort a whole batch.
        print(f"Error processing requirement: {requirement}. Error: {e}")
        return None
|
||||
|
||||
|
||||
async def evaluate_all(agent, k=None):
    """Evaluate all tasks in DABench using the specified baseline agent.

    A formatted prompt is generated for every key in ``DABench().answers``.
    The prompts are processed in groups of size k; the tasks of one group run
    concurrently, and the groups run one after another. The evaluation
    returned by ``DABench.eval_all`` is printed.

    Args:
        agent: The baseline agent used for making predictions.
        k (int, optional): The number of tasks to process concurrently in
            each group. When omitted, all tasks run in a single group.

    Raises:
        ValueError: If k is given and is smaller than 1.
    """
    if k is not None and k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    DA = DABench()  # Create an instance of DABench to access its methods and data

    # Pair every task id with its formatted prompt
    pending = [(key, DA.generate_formatted_prompt(key)) for key, _ in DA.answers.items()]

    # max(..., 1) keeps the range step non-zero when there are no tasks
    group_size = k if k is not None else max(len(pending), 1)

    id_list, predictions = [], []  # Parallel lists of ids and their predictions

    # NOTE(review): every task in a group shares the same agent instance;
    # confirm the agent tolerates concurrent run() calls.
    for start in range(0, len(pending), group_size):
        current_group = pending[start : start + group_size]

        # Coroutines are created per group so none is left un-awaited
        group_predictions = await asyncio.gather(
            *(get_prediction(agent, requirement) for _, requirement in current_group)
        )

        # Drop failed predictions together with their ids so both lists stay
        # the same length and in the same order.
        # NOTE(review): assumes eval_all matches ids to predictions by position — confirm.
        for (key, _), prediction in zip(current_group, group_predictions):
            if prediction is not None:
                id_list.append(key)
                predictions.append(prediction)

    # Evaluate the results using all valid predictions and print the evaluation
    print(DA.eval_all(id_list, predictions))
|
||||
|
||||
|
||||
def main(k=5):
    """Run the DABench evaluation with a DataInterpreter agent.

    Args:
        k (int): The number of tasks passed to ``evaluate_all`` as the size
            of each concurrently processed group. Defaults to 5.
    """
    agent = DataInterpreter()  # Create an instance of the DataInterpreter agent
    asyncio.run(evaluate_all(agent, k))  # Drive the async evaluation to completion
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -11,12 +11,11 @@ async def main():
|
|||
"""Evaluate all"""
|
||||
DA = DABench()
|
||||
id_list, predictions, labels, is_true = [], [], [], []
|
||||
|
||||
for key, value in DA.answers.items():
|
||||
id_list.append(key)
|
||||
labels.append(str(DA.get_answer(key)))
|
||||
try:
|
||||
requirement = DA.get_prompt(key)
|
||||
requirement = DA.generate_formatted_prompt(key)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
logger.info(result)
|
||||
|
|
@ -24,13 +23,11 @@ async def main():
|
|||
temp_prediction, temp_istrue = DA.eval(key, str(result))
|
||||
is_true.append(str(temp_istrue))
|
||||
predictions.append(str(temp_prediction))
|
||||
|
||||
except:
|
||||
is_true.append(str(DA.eval(key, "")))
|
||||
predictions.append(str(""))
|
||||
df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
|
||||
|
||||
df.to_excel("output.xlsx", index=False)
|
||||
df.to_excel("DABench_output.xlsx", index=False)
|
||||
print(DA.eval_all(id_list, predictions))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ from metagpt.utils.recovery_util import save_history
|
|||
|
||||
|
||||
async def main(id=0):
|
||||
"""Evaluate one task"""
|
||||
DA = DABench()
|
||||
requirement = DA.get_prompt(id)
|
||||
requirement = DA.generate_formatted_prompt(id)
|
||||
di = DataInterpreter()
|
||||
result = await di.run(requirement)
|
||||
logger.info(result)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue