Signed-off-by: kit <101046518@qq.com>
kit 2024-10-28 22:02:12 +08:00
parent a7fa56a75b
commit b97fa40e0e
5 changed files with 22 additions and 20 deletions

View file

@@ -8,6 +8,7 @@ import nest_asyncio
 from examples.di.requirements_prompt import DABENCH
 from metagpt.const import DABENCH_PATH
 from metagpt.logs import logger
+from metagpt.utils.exceptions import handle_exception
@@ -473,14 +474,14 @@ class DABench:
 if __name__ == "__main__":
-    DA = DABench()
+    bench = DABench()
     id = 0
     prediction = "@mean_fare[34.65]"
-    print(DA.eval(id, prediction))
+    logger.info(bench.eval(id, prediction))
     ids = [0, 5, 6]
     predictions = [
         "@mean_fare[34.89]",
         "@correlation_coefficient[0.21]",
         "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
     ]
-    print(DA.eval_all(ids, predictions))
+    logger.info(bench.eval_all(ids, predictions))
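The other change in this file is the new handle_exception import. The diff does not show where it is applied, but metagpt.utils.exceptions.handle_exception is presumably used as an exception-guarding decorator around the evaluation helpers. A minimal, self-contained sketch of that pattern (all names below are illustrative, not MetaGPT's actual API):

```
# Sketch of an exception-guarding decorator, analogous to what
# metagpt.utils.exceptions.handle_exception presumably provides.
# `guard` and `eval_one` are hypothetical names for illustration.
import functools
import logging

logger = logging.getLogger(__name__)

def guard(default=None):
    """Return a decorator that logs exceptions and returns `default` instead of raising."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                logger.warning("%s failed: %s", func.__name__, exc)
                return default
        return wrapper
    return decorator

@guard(default=(None, False))
def eval_one(bench, task_id, prediction):
    # A failing eval now yields (None, False) instead of crashing the whole run.
    return bench.eval(task_id, prediction)
```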

View file

@@ -9,7 +9,7 @@ ## Dataset
 ```
 ## How to run
 ```
-python run_InfiAgent-DABench_sigle.py --id x # run a task, x represents the id of the question you want to test
+python run_InfiAgent-DABench_single.py --id x # run a task, x represents the id of the question you want to test
 python run_InfiAgent-DABench_all.py # Run all tasks serially
 python run_InfiAgent-DABench.py --k x # Run all tasks in parallel, x represents the number of parallel tasks at a time
 ```
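For example, question 0 (the same task exercised in the DABench smoke test above) would be run as:

```
python run_InfiAgent-DABench_single.py --id 0
```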

View file

@@ -3,6 +3,7 @@ import json
 from DABench import DABench
+from metagpt.logs import logger
 from metagpt.roles.di.data_interpreter import DataInterpreter
@@ -30,7 +31,7 @@ async def get_prediction(agent, requirement):
         return prediction  # Return the extracted prediction
     except Exception as e:
         # Log an error message if an exception occurs during processing
-        print(f"Error processing requirement: {requirement}. Error: {e}")
+        logger.info(f"Error processing requirement: {requirement}. Error: {e}")
         return None  # Return None in case of an error
@@ -43,13 +44,13 @@ async def evaluate_all(agent, k):
         agent: The baseline agent used for making predictions.
         k (int): The number of tasks to process in each group concurrently.
     """
-    DA = DABench()  # Create an instance of DABench to access its methods and data
+    bench = DABench()  # Create an instance of DABench to access its methods and data
     id_list, predictions = [], []  # Initialize lists to store IDs and predictions
     tasks = []  # Initialize a list to hold the tasks
     # Iterate over the answers in DABench to generate tasks
-    for key, value in DA.answers.items():
-        requirement = DA.generate_formatted_prompt(key)  # Generate a formatted prompt for the current key
+    for key, value in bench.answers.items():
+        requirement = bench.generate_formatted_prompt(key)  # Generate a formatted prompt for the current key
         tasks.append(get_prediction(agent, requirement))  # Append the prediction task to the tasks list
         id_list.append(key)  # Append the current key to the ID list
@@ -62,8 +63,8 @@ async def evaluate_all(agent, k):
         # Filter out any None values from the predictions and extend the predictions list
         predictions.extend(pred for pred in group_predictions if pred is not None)
-    # Evaluate the results using all valid predictions and print the evaluation
-    print(DA.eval_all(id_list, predictions))
+    # Evaluate the results using all valid predictions and log the evaluation
+    logger.info(bench.eval_all(id_list, predictions))
 def main(k=5):
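The core idea of this runner, awaiting prediction coroutines in groups of k rather than all at once, can be sketched independently of MetaGPT. `fake_predict` below is a stand-in for `get_prediction`, and the grouping loop mirrors what `evaluate_all` does:

```
import asyncio

async def fake_predict(task_id: int) -> str:
    """Stand-in for get_prediction: pretend to query an agent."""
    await asyncio.sleep(0.01)
    return f"@answer[{task_id}]"

async def evaluate_in_groups(ids, k=5):
    """Run prediction coroutines k at a time, preserving id order."""
    predictions = []
    for start in range(0, len(ids), k):
        group = ids[start:start + k]
        # asyncio.gather bounds concurrency to the size of this group.
        results = await asyncio.gather(*(fake_predict(i) for i in group))
        predictions.extend(r for r in results if r is not None)
    return predictions

print(asyncio.run(evaluate_in_groups(list(range(12)), k=5)))
```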

View file

@@ -9,26 +9,26 @@ from metagpt.utils.recovery_util import save_history
 async def main():
     """Evaluate all"""
-    DA = DABench()
+    bench = DABench()
     id_list, predictions, labels, is_true = [], [], [], []
-    for key, value in DA.answers.items():
+    for key, value in bench.answers.items():
         id_list.append(key)
-        labels.append(str(DA.get_answer(key)))
+        labels.append(str(bench.get_answer(key)))
         try:
-            requirement = DA.generate_formatted_prompt(key)
+            requirement = bench.generate_formatted_prompt(key)
             di = DataInterpreter()
             result = await di.run(requirement)
             logger.info(result)
             save_history(role=di)
-            temp_prediction, temp_istrue = DA.eval(key, str(result))
+            temp_prediction, temp_istrue = bench.eval(key, str(result))
             is_true.append(str(temp_istrue))
             predictions.append(str(temp_prediction))
         except:
-            is_true.append(str(DA.eval(key, "")))
+            is_true.append(str(bench.eval(key, "")))
             predictions.append(str(""))
     df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
     df.to_excel("DABench_output.xlsx", index=False)
-    logger.info(DA.eval_all(id_list, predictions))
+    logger.info(bench.eval_all(id_list, predictions))
 if __name__ == "__main__":
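The reporting step of this serial runner boils down to three columns and one Excel write. A standalone sketch with sample values mirroring the smoke test above (note that pandas needs an engine such as openpyxl installed to write .xlsx):

```
import pandas as pd

# Sample values mirroring the diff's smoke test; real runs fill these per task.
labels = ["34.65"]
predictions = ["@mean_fare[34.65]"]
is_true = ["True"]

df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
df.to_excel("DABench_output.xlsx", index=False)  # requires an xlsx engine, e.g. openpyxl
```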

View file

@@ -8,13 +8,13 @@ from metagpt.utils.recovery_util import save_history
 async def main(id=0):
     """Evaluate one task"""
-    DA = DABench()
-    requirement = DA.generate_formatted_prompt(id)
+    bench = DABench()
+    requirement = bench.generate_formatted_prompt(id)
     di = DataInterpreter()
     result = await di.run(requirement)
     logger.info(result)
     save_history(role=di)
-    _, is_correct = DA.eval(id, str(result))
+    _, is_correct = bench.eval(id, str(result))
     logger.info(f"Prediction is {'correct' if is_correct else 'incorrect'}.")