Refactor MCTS class to handle role running errors and improve error logging

This commit is contained in:
Yizhou Chi 2024-09-05 10:12:37 +08:00
parent 58d7b14007
commit c16286a006
2 changed files with 15 additions and 6 deletions

View file

@ -3,6 +3,7 @@ import os
import pickle
import random
import numpy as np
import pandas as pd
from expo.dataset import generate_task_requirement, get_split_dataset_path
@ -209,16 +210,20 @@ class Node:
await role.run(with_message="continue")
else:
await role.run(with_message=self.state["requirement"])
score_dict = await role.get_score()
score_dict = self.evaluate_simulation(score_dict)
self.raw_reward = score_dict
run_finished = True
except Exception as e:
mcts_logger.log("MCTS", f"Error in running the role: {e}")
num_runs += 1
if not run_finished:
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
return {"test_score": 0, "dev_score": 0, "score": 0}
score_dict = await role.get_score()
score_dict = self.evaluate_simulation(score_dict)
self.raw_reward = score_dict
if self.state["low_is_better"]:
score_dict = {"test_score": np.inf, "dev_score": np.inf, "score": np.inf}
else:
score_dict = {"test_score": 0, "dev_score": 0, "score": 0}
self.raw_reward = score_dict
if self.state["low_is_better"]:
# normalize the score to lie between 0 and 1, so that higher is better
def normalize_score(score):

View file

@ -54,11 +54,15 @@ class Experimenter:
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}
)
self.save_result(results) # save intermediate results
dev_scores = [result["score_dict"]["dev_score"] for result in results]
dev_scores = [
result["score_dict"]["dev_score"] for result in results if result["score_dict"]["dev_score"] != -1
]
best_dev_score = max(dev_scores) if not self.args.low_is_better else min(dev_scores)
best_score_idx = dev_scores.index(best_dev_score)
test_scores = [result["score_dict"]["test_score"] for result in results]
test_scores = [
result["score_dict"]["test_score"] for result in results if result["score_dict"]["dev_score"] != -1
]
avg_score = sum(test_scores) / len(test_scores)
global_best_score = max(test_scores) if not self.args.low_is_better else min(test_scores)