add tree visualization script and function

This commit is contained in:
Yizhou Chi 2024-10-17 10:27:33 +08:00
parent a46f575361
commit 0f01c07b83
2 changed files with 128 additions and 3 deletions

View file

@ -1,5 +1,8 @@
import textwrap
import matplotlib.pyplot as plt
import networkx as nx
from expo.MCTS import Node
NODE_TEMPLATE = """\
@ -11,6 +14,9 @@ Score: {score}, Visits: {num_visits}
"""
NODE_SIZE = 12000
NODE_FONT_SIZE = 18
def get_role_plans(role):
plans = role.planner.plan.tasks
@ -42,7 +48,7 @@ def get_tree_text(node: Node):
id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits
)
def visualize_tree(node, depth=0, previous_plans=None):
def visualize_tree_text(node, depth=0, previous_plans=None):
text = ""
if node is not None:
text += visualize_node(node, previous_plans)
@ -50,10 +56,106 @@ def get_tree_text(node: Node):
code_set.update({task.instruction for task in role.planner.plan.tasks})
previous_plans = get_role_plans(role)
for child in node.children:
text += textwrap.indent(visualize_tree(child, depth + 1, previous_plans), "\t")
text += textwrap.indent(visualize_tree_text(child, depth + 1, previous_plans), "\t")
return text
num_simulations = node.visited
text = f"Number of simulations: {num_simulations}\n"
text += visualize_tree(node)
text += visualize_tree_text(node)
return text, len(code_set)
def get_node_color(node):
if node["visits"] == 0:
return "#D3D3D3"
else:
# The higher the avg_value, the more intense the color
# avg_value is between 0 and 1
avg_value = node["avg_value"]
# Convert avg_value to a color ranging from red (low) to green (high)
red = int(255 * (1 - avg_value))
green = int(255 * avg_value)
return f"#{red:02X}{green:02X}00"
def visualize_tree(graph, save_path=""):
# Use a hierarchical layout for tree-like visualization
pos = nx.spring_layout(graph, k=0.9, iterations=50)
plt.figure(figsize=(30, 20)) # Further increase figure size for better visibility
# Calculate node levels
root = "0"
levels = nx.single_source_shortest_path_length(graph, root)
max_level = max(levels.values())
# Adjust y-coordinates based on levels and x-coordinates to prevent overlap
nodes_by_level = {}
for node, level in levels.items():
if level not in nodes_by_level:
nodes_by_level[level] = []
nodes_by_level[level].append(node)
for level, nodes in nodes_by_level.items():
y = 1 - level / max_level
x_step = 1.0 / (len(nodes) + 1)
for i, node in enumerate(sorted(nodes)):
pos[node] = ((i + 1) * x_step, y)
# Draw edges
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, arrowsize=40, width=3)
# Draw nodes
node_colors = [get_node_color(graph.nodes[node]) for node in graph.nodes]
nx.draw_networkx_nodes(graph, pos, node_size=NODE_SIZE, node_color=node_colors)
# Add labels to nodes
labels = nx.get_node_attributes(graph, "label")
nx.draw_networkx_labels(graph, pos, labels, font_size=NODE_FONT_SIZE)
# Add instructions to the right side of nodes
instructions = nx.get_node_attributes(graph, "instruction")
for node, (x, y) in pos.items():
wrapped_text = textwrap.fill(instructions[node], width=30) # Adjust width as needed
plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center")
plt.title("MCTS Tree Visualization", fontsize=40)
plt.axis("off") # Turn off axis
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.show()
def build_tree_recursive(graph, parent_id, node, start_task_id=2):
"""
Recursively builds the entire tree starting from the root node.
Adds nodes and edges to the NetworkX graph.
"""
role = node.load_role()
depth = node.get_depth()
if depth == 0:
instruction = "\n\n".join([role.planner.plan.tasks[i].instruction for i in range(start_task_id)])
else:
instruction = role.planner.plan.tasks[depth + start_task_id - 1].instruction
print(instruction)
# Add the current node with attributes to the graph
dev_score = node.raw_reward.get("dev_score", 0) * 100
avg_score = node.avg_value() * 100
graph.add_node(
parent_id,
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
avg_value=node.avg_value(),
dev_score=dev_score,
visits=node.visited,
instruction=instruction,
)
# Stopping condition: if the node has no children, return
if not node.children:
return
# Recursively create all child nodes
for i, child in enumerate(node.children):
child_id = f"{parent_id}-{i}"
graph.add_edge(parent_id, child_id)
build_tree_recursive(graph, child_id, child)

View file

@ -0,0 +1,23 @@
import networkx as nx
from expo.evaluation.visualize_mcts import build_tree_recursive, visualize_tree
from expo.MCTS import MCTS, create_initial_state, initialize_di_root_node
from expo.run_experiment import get_args
from expo.utils import DATA_CONFIG
if __name__ == "__main__":
args = get_args()
data_config = DATA_CONFIG
state = create_initial_state(args.task, 0, data_config, args=args)
role, node = initialize_di_root_node(state)
mcts = MCTS(
root_node=node,
max_depth=5,
use_fixed_insights=False,
)
mcts.load_tree()
root = mcts.root_node
G = nx.DiGraph()
tree = build_tree_recursive(G, "0", root)
visualize_tree(tree, save_path="../results/tree.png")