From 510136ab17e32e05d5a443fe832d57a7ba154605 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Thu, 17 Oct 2024 10:37:16 +0800 Subject: [PATCH] allowing whether to show instructions --- expo/evaluation/visualize_mcts.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/expo/evaluation/visualize_mcts.py b/expo/evaluation/visualize_mcts.py index e429789fd..6a8869670 100644 --- a/expo/evaluation/visualize_mcts.py +++ b/expo/evaluation/visualize_mcts.py @@ -78,7 +78,7 @@ def get_node_color(node): return f"#{red:02X}{green:02X}00" -def visualize_tree(graph, save_path=""): +def visualize_tree(graph, show_instructions=False, save_path=""): # Use a hierarchical layout for tree-like visualization pos = nx.spring_layout(graph, k=0.9, iterations=50) @@ -113,11 +113,12 @@ def visualize_tree(graph, save_path=""): labels = nx.get_node_attributes(graph, "label") nx.draw_networkx_labels(graph, pos, labels, font_size=NODE_FONT_SIZE) - # Add instructions to the right side of nodes - instructions = nx.get_node_attributes(graph, "instruction") - for node, (x, y) in pos.items(): - wrapped_text = textwrap.fill(instructions[node], width=30) # Adjust width as needed - plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center") + if show_instructions: + # Add instructions to the right side of nodes + instructions = nx.get_node_attributes(graph, "instruction") + for node, (x, y) in pos.items(): + wrapped_text = textwrap.fill(instructions[node], width=30) # Adjust width as needed + plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center") plt.title("MCTS Tree Visualization", fontsize=40) plt.axis("off") # Turn off axis