diff --git a/expo/MCTS.py b/expo/MCTS.py
index b2ad824e5..365315330 100644
--- a/expo/MCTS.py
+++ b/expo/MCTS.py
@@ -191,6 +191,7 @@ class Node:
 
     def evaluate_simulation(self, score_dict):
         scores = {"dev_score": self.evaluate_prediction("dev"), "test_score": self.evaluate_prediction("test")}
+        scores["score"] = scores["dev_score"]
         score_dict.update(scores)
         return score_dict
 
@@ -345,7 +346,7 @@ class MCTS:
                 if node.raw_value == 0:
                     reward = await self.simulate(node)
                 else:
-                    reward = {"test_score": node.raw_value, "score": node.value}
+                    reward = {"test_score": node.raw_value, "score": node.raw_reward["score"]}
                 mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
                 self.backpropagate(node, reward)
             else: