diff --git a/expo/MCTS.py b/expo/MCTS.py index 4090331cd..360baac8d 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -279,7 +279,8 @@ class MCTS: def best_path(self, root: Node): best_child = root - best_score = 0 + global_best_score = root.normalized_reward["test_score"] + dev_best_score = root.normalized_reward["dev_score"] def bfs(node: Node, best_score, best_child: Node, split): assert split in ["test_score", "dev_score"] @@ -294,10 +295,10 @@ class MCTS: best_score, best_child = bfs(child, best_score, best_child, split) return best_score, best_child - _, best_child = bfs(root, best_score, best_child, "test_score") - _, dev_best_child = bfs(root, best_score, best_child, "dev_score") + _, global_best_child = bfs(root, global_best_score, best_child, "test_score") + _, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score") - return {"dev_best": dev_best_child, "global_best": best_child} + return {"dev_best": dev_best_child, "global_best": global_best_child} def get_num_simulations(self): return self.root_node.visited