Update MCTS best_path logic

This commit is contained in:
Yizhou Chi 2024-09-09 14:52:25 +08:00
parent 401ca97846
commit 9ba0d217fc

View file

@ -279,7 +279,8 @@ class MCTS:
def best_path(self, root: Node):
best_child = root
best_score = 0
global_best_score = root.normalized_reward["test_score"]
dev_best_score = root.normalized_reward["dev_score"]
def bfs(node: Node, best_score, best_child: Node, split):
assert split in ["test_score", "dev_score"]
@ -294,10 +295,10 @@ class MCTS:
best_score, best_child = bfs(child, best_score, best_child, split)
return best_score, best_child
_, best_child = bfs(root, best_score, best_child, "test_score")
_, dev_best_child = bfs(root, best_score, best_child, "dev_score")
_, global_best_child = bfs(root, global_best_score, best_child, "test_score")
_, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score")
return {"dev_best": dev_best_child, "global_best": best_child}
return {"dev_best": dev_best_child, "global_best": global_best_child}
def get_num_simulations(self):
    """Return the ``visited`` value stored on ``self.root_node``."""
    root = self.root_node
    return root.visited