From c16286a006c5d85c8b9f34c0183cde4c2c9849e8 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Thu, 5 Sep 2024 10:12:37 +0800 Subject: [PATCH] Refactor MCTS class to handle role running errors and improve error logging --- expo/MCTS.py | 13 +++++++++---- expo/experimenter/experimenter.py | 8 ++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/expo/MCTS.py b/expo/MCTS.py index 7c03e2e86..b2ad824e5 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -3,6 +3,7 @@ import os import pickle import random +import numpy as np import pandas as pd from expo.dataset import generate_task_requirement, get_split_dataset_path @@ -209,16 +210,20 @@ class Node: await role.run(with_message="continue") else: await role.run(with_message=self.state["requirement"]) + score_dict = await role.get_score() + score_dict = self.evaluate_simulation(score_dict) + self.raw_reward = score_dict run_finished = True except Exception as e: mcts_logger.log("MCTS", f"Error in running the role: {e}") num_runs += 1 if not run_finished: mcts_logger.log("MCTS", f"Role {role.node_id} failed to run") - return {"test_score": 0, "dev_score": 0, "score": 0} - score_dict = await role.get_score() - score_dict = self.evaluate_simulation(score_dict) - self.raw_reward = score_dict + if self.state["low_is_better"]: + score_dict = {"test_score": np.inf, "dev_score": np.inf, "score": np.inf} + else: + score_dict = {"test_score": 0, "dev_score": 0, "score": 0} + self.raw_reward = score_dict if self.state["low_is_better"]: # normalized the score to be between 0 and 1, and higher is better def normalize_score(score): diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index 83dde80b9..4161aef3d 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -54,11 +54,15 @@ class Experimenter: {"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} ) self.save_result(results) # save intermediate results - dev_scores = [result["score_dict"]["dev_score"] for result in results] + dev_scores = [ + result["score_dict"]["dev_score"] for result in results if result["score_dict"]["dev_score"] != -1 + ] best_dev_score = max(dev_scores) if not self.args.low_is_better else min(dev_scores) best_score_idx = dev_scores.index(best_dev_score) - test_scores = [result["score_dict"]["test_score"] for result in results] + test_scores = [ + result["score_dict"]["test_score"] for result in results if result["score_dict"]["dev_score"] != -1 + ] avg_score = sum(test_scores) / len(test_scores) global_best_score = max(test_scores) if not self.args.low_is_better else min(test_scores)