add visit order

This commit is contained in:
garylin2099 2024-10-21 10:13:36 +00:00
parent d304fc36de
commit 3a8fdc6a5e
2 changed files with 7 additions and 4 deletions

View file

@ -128,7 +128,7 @@ def visualize_tree(graph, show_instructions=False, save_path=""):
plt.show()
def build_tree_recursive(graph, parent_id, node, start_task_id=2):
def build_tree_recursive(graph, parent_id, node, node_order, start_task_id=2):
"""
Recursively builds the entire tree starting from the root node.
Adds nodes and edges to the NetworkX graph.
@ -143,9 +143,10 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
# Add the current node with attributes to the graph
dev_score = node.raw_reward.get("dev_score", 0) * 100
avg_score = node.avg_value() * 100
order = node_order.index(node.id) if node.id in node_order else ""
graph.add_node(
parent_id,
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}\nOrder: {order}",
avg_value=node.avg_value(),
dev_score=dev_score,
visits=node.visited,
@ -159,4 +160,4 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
for i, child in enumerate(node.children):
child_id = f"{parent_id}-{i}"
graph.add_edge(parent_id, child_id)
build_tree_recursive(graph, child_id, child)
build_tree_recursive(graph, child_id, child, node_order)

View file

@ -17,7 +17,9 @@ if __name__ == "__main__":
)
mcts.load_tree()
mcts.load_node_order()
root = mcts.root_node
node_order = mcts.node_order
G = nx.DiGraph()
build_tree_recursive(G, "0", root)
build_tree_recursive(G, "0", root, node_order)
visualize_tree(G, save_path=f"results/{args.task}-tree.png")