diff --git a/expo/MCTS.py b/expo/MCTS.py
index aa4ade944..4564cd682 100644
--- a/expo/MCTS.py
+++ b/expo/MCTS.py
@@ -1,3 +1,4 @@
+import json
 import math
 import os
 import pickle
@@ -240,6 +241,7 @@ class MCTS:
     max_depth: int = 5
     c_explore: float = 1.4
     c_unvisited: float = 0.8
+    node_order: list = []
 
     def __init__(self, root_node, max_depth, use_fixed_insights):
         self.root_node = root_node
@@ -306,11 +308,54 @@ class MCTS:
         _, global_best_child = bfs(root, global_best_score, best_child, "test_score")
         _, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score")
 
-        return {"dev_best": dev_best_child, "global_best": global_best_child}
+        return {"dev_best": dev_best_child, "global_best": global_best_child, "scores": self.get_score_order_dict()}
 
     def get_num_simulations(self):
         return self.root_node.visited
 
+    def save_node_order(self, node_id):
+        """Record ``node_id`` in the node order and write the order to disk.
+
+        The whole list is rewritten to ``node_order.json`` in the root node's
+        ``node_dir`` on every call.
+        """
+        # Rebind instead of appending in place: the class-level default list is
+        # shared by all MCTS instances, so mutating it would leak node ids
+        # between searches running in the same process.
+        self.node_order = [*self.node_order, node_id]
+        with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f:
+            json.dump(self.node_order, f)
+
+    def load_node_order(self):
+        """Replace ``self.node_order`` with the list stored in ``node_order.json``.
+
+        The file is read from the root node's ``node_dir``; ``open`` raises
+        ``FileNotFoundError`` if it has not been written yet.
+        """
+        with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "r") as f:
+            self.node_order = json.load(f)
+
+    def get_score_order_dict(self):
+        """Collect dev/test scores of the recorded nodes, in recorded order.
+
+        Returns:
+            dict with keys ``dev``, ``test`` (from ``normalized_reward``) and
+            ``dev_raw``, ``test_raw`` (from ``raw_reward``), each a list with
+            one entry per id in ``self.node_order``.
+        """
+        scores = {"dev": [], "test": [], "dev_raw": [], "test_raw": []}
+        for node_id in self.node_order:
+            # Build a placeholder node carrying only the id, then swap it for
+            # the node returned by ``load_node()``.
+            node = Node(parent=None, state=self.root_node.state, action=None, value=0)
+            node.id = node_id
+            node = node.load_node()
+            scores["dev"].append(node.normalized_reward["dev_score"])
+            scores["test"].append(node.normalized_reward["test_score"])
+            scores["dev_raw"].append(node.raw_reward["dev_score"])
+            scores["test_raw"].append(node.raw_reward["test_score"])
+        return scores
+
     async def search(self, state, rollouts, load_tree=False, reflection=False):
         role, root = initialize_di_root_node(state, reflection=reflection)
         self.root_node = root
@@ -329,8 +374,12 @@ class MCTS:
             self.backpropagate(root, reward)
             node, reward = await self.expand_and_simulate(root)
             # self.backpropagate(node, reward)
+            self.save_node_order(root.id)
+            self.save_node_order(node.id)
         else:
             root = self.root_node
+            self.load_node_order()
+
         for _ in range(rollouts):  # number of rollouts
             mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
             node = self.select(root)
@@ -344,6 +393,7 @@ class MCTS:
             else:
                 node, reward = await self.expand_and_simulate(node)
                 # self.backpropagate(node, reward)
+            self.save_node_order(node.id)
         return self.best_path(root)
 
     async def expand_and_simulate(self, node):
@@ -373,6 +423,7 @@ class MCTS:
                 self.root_node = pickle.load(f)
             self.children[self.root_node] = self.root_node.children
             load_children_node(self.root_node)
+
             if self.children:
                 return True
         return False
diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py
index 5fb00ca8d..bd803bff1 100644
--- a/expo/experimenter/mcts.py
+++ b/expo/experimenter/mcts.py
@@ -29,6 +29,7 @@ class MCTSExperimenter(Experimenter):
         )
         best_node = best_nodes["global_best"]
         dev_best_node = best_nodes["dev_best"]
+        score_dict = best_nodes["scores"]
 
         self.copy_notebook(best_node, "best")
         self.copy_notebook(dev_best_node, "dev_best")
@@ -50,6 +51,7 @@ class MCTSExperimenter(Experimenter):
                 "user_requirement": best_node.state["requirement"],
                 "tree_text": text,
                 "args": vars(self.args),
+                "scores": score_dict,
             }
         ]
         self.save_result(results)