From 9ba0d217fc25ae4732394173a1c3626de5679128 Mon Sep 17 00:00:00 2001
From: Yizhou Chi
Date: Mon, 9 Sep 2024 14:52:25 +0800
Subject: [PATCH] update mcts logic

---
 expo/MCTS.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/expo/MCTS.py b/expo/MCTS.py
index 4090331cd..360baac8d 100644
--- a/expo/MCTS.py
+++ b/expo/MCTS.py
@@ -279,7 +279,8 @@ class MCTS:
 
     def best_path(self, root: Node):
         best_child = root
-        best_score = 0
+        global_best_score = root.normalized_reward["test_score"]
+        dev_best_score = root.normalized_reward["dev_score"]
 
         def bfs(node: Node, best_score, best_child: Node, split):
             assert split in ["test_score", "dev_score"]
@@ -294,10 +295,10 @@ class MCTS:
                 best_score, best_child = bfs(child, best_score, best_child, split)
             return best_score, best_child
 
-        _, best_child = bfs(root, best_score, best_child, "test_score")
-        _, dev_best_child = bfs(root, best_score, best_child, "dev_score")
+        _, global_best_child = bfs(root, global_best_score, best_child, "test_score")
+        _, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score")
 
-        return {"dev_best": dev_best_child, "global_best": best_child}
+        return {"dev_best": dev_best_child, "global_best": global_best_child}
 
     def get_num_simulations(self):
         return self.root_node.visited