From e07ed0df8bcee81cd7889964d01a45cba5d92fd2 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Fri, 6 Sep 2024 10:42:24 +0800 Subject: [PATCH] fix mcts bug --- expo/MCTS.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/expo/MCTS.py b/expo/MCTS.py index b2ad824e5..365315330 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -191,6 +191,7 @@ class Node: def evaluate_simulation(self, score_dict): scores = {"dev_score": self.evaluate_prediction("dev"), "test_score": self.evaluate_prediction("test")} + scores["score"] = scores["dev_score"] score_dict.update(scores) return score_dict @@ -345,7 +346,7 @@ class MCTS: if node.raw_value == 0: reward = await self.simulate(node) else: - reward = {"test_score": node.raw_value, "score": node.value} + reward = {"test_score": node.raw_value, "score": node.raw_reward["score"]} mcts_logger.log("MCTS", f"Terminal node's reward: {reward}") self.backpropagate(node, reward) else: