add step score

This commit is contained in:
Yizhou Chi 2024-09-26 20:25:36 +08:00
parent 68c672d438
commit 2b67355358
2 changed files with 32 additions and 1 deletions

View file

@ -1,3 +1,4 @@
import json
import math
import os
import pickle
@ -240,6 +241,7 @@ class MCTS:
max_depth: int = 5
c_explore: float = 1.4
c_unvisited: float = 0.8
node_order: list = []
def __init__(self, root_node, max_depth, use_fixed_insights):
self.root_node = root_node
@ -306,11 +308,32 @@ class MCTS:
_, global_best_child = bfs(root, global_best_score, best_child, "test_score")
_, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score")
return {"dev_best": dev_best_child, "global_best": global_best_child}
return {"dev_best": dev_best_child, "global_best": global_best_child, "scores": self.get_score_order_dict()}
def get_num_simulations(self):
return self.root_node.visited
def save_node_order(self, node_id):
self.node_order.append(node_id)
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f:
json.dump(self.node_order, f)
def load_node_order(self):
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "r") as f:
self.node_order = json.load(f)
def get_score_order_dict(self):
scores = {"dev": [], "test": [], "dev_raw": [], "test_raw": []}
for node_id in self.node_order:
node = Node(parent=None, state=self.root_node.state, action=None, value=0)
node.id = node_id
node = node.load_node()
scores["dev"].append(node.normalized_reward["dev_score"])
scores["test"].append(node.normalized_reward["test_score"])
scores["dev_raw"].append(node.raw_reward["dev_score"])
scores["test_raw"].append(node.raw_reward["test_score"])
return scores
async def search(self, state, rollouts, load_tree=False, reflection=False):
role, root = initialize_di_root_node(state, reflection=reflection)
self.root_node = root
@ -329,8 +352,12 @@ class MCTS:
self.backpropagate(root, reward)
node, reward = await self.expand_and_simulate(root)
# self.backpropagate(node, reward)
self.save_node_order(root.id)
self.save_node_order(node.id)
else:
root = self.root_node
self.load_node_order()
for _ in range(rollouts): # number of rollouts
mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
node = self.select(root)
@ -344,6 +371,7 @@ class MCTS:
else:
node, reward = await self.expand_and_simulate(node)
# self.backpropagate(node, reward)
self.save_node_order(node.id)
return self.best_path(root)
async def expand_and_simulate(self, node):
@ -373,6 +401,7 @@ class MCTS:
self.root_node = pickle.load(f)
self.children[self.root_node] = self.root_node.children
load_children_node(self.root_node)
if self.children:
return True
return False

View file

@ -29,6 +29,7 @@ class MCTSExperimenter(Experimenter):
)
best_node = best_nodes["global_best"]
dev_best_node = best_nodes["dev_best"]
score_dict = best_nodes["scores"]
self.copy_notebook(best_node, "best")
self.copy_notebook(dev_best_node, "dev_best")
@ -50,6 +51,7 @@ class MCTSExperimenter(Experimenter):
"user_requirement": best_node.state["requirement"],
"tree_text": text,
"args": vars(self.args),
"scores": score_dict,
}
]
self.save_result(results)