fix mcts bug

This commit is contained in:
Yizhou Chi 2024-09-06 10:42:24 +08:00
parent 0e27e3d8be
commit e07ed0df8b

View file

@ -191,6 +191,7 @@ class Node:
def evaluate_simulation(self, score_dict):
scores = {"dev_score": self.evaluate_prediction("dev"), "test_score": self.evaluate_prediction("test")}
scores["score"] = scores["dev_score"]
score_dict.update(scores)
return score_dict
@ -345,7 +346,7 @@ class MCTS:
if node.raw_value == 0:
reward = await self.simulate(node)
else:
reward = {"test_score": node.raw_value, "score": node.value}
reward = {"test_score": node.raw_value, "score": node.raw_reward["score"]}
mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
self.backpropagate(node, reward)
else: