import textwrap
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed for annotations only; keeps the helpers importable on their own.
    from expo.MCTS import Node

NODE_SIZE = 12000
NODE_FONT_SIZE = 18
UNVISITED_NODE_COLOR = "#D3D3D3"


def get_node_color(node):
    """Return the ``#RRGGBB`` fill colour for one graph node.

    Args:
        node: The node's attribute mapping. Its ``"visits"`` and
            ``"avg_value"`` entries are read.

    Returns:
        Light grey for a node with zero visits; otherwise a shade running
        from red (``avg_value`` 0) to green (``avg_value`` 1).
    """
    if node["visits"] == 0:
        return UNVISITED_NODE_COLOR
    # Clamp so an out-of-range average cannot produce a malformed hex string
    # (a negative channel would format as e.g. "-33").
    avg_value = min(1.0, max(0.0, node["avg_value"]))
    red = int(255 * (1 - avg_value))
    green = int(255 * avg_value)
    return f"#{red:02X}{green:02X}00"


def _node_sort_key(node_id):
    """Sort ids such as "0-2" before "0-10" by comparing their integer parts."""
    try:
        return (0, tuple(int(part) for part in str(node_id).split("-")), "")
    except ValueError:
        # Ids that are not dash-separated integers fall back to string order.
        return (1, (), str(node_id))


def _layered_positions(levels):
    """Place nodes on one row per depth, spread evenly along x in [0, 1]."""
    nodes_by_level = {}
    for node, level in levels.items():
        nodes_by_level.setdefault(level, []).append(node)

    max_level = max(nodes_by_level)
    positions = {}
    for level, nodes in nodes_by_level.items():
        # A tree consisting of the root alone has max_level == 0; pin it to
        # the top row instead of dividing by zero.
        y = 1 - level / max_level if max_level else 1.0
        x_step = 1.0 / (len(nodes) + 1)
        for i, node in enumerate(sorted(nodes, key=_node_sort_key)):
            positions[node] = ((i + 1) * x_step, y)
    return positions


def visualize_tree(graph, save_path="", root="0"):
    """Draw the tree held in ``graph`` with matplotlib.

    Nodes reachable from ``root`` are laid out in rows by their distance from
    it; any other node keeps the position given by the spring layout. Each
    node is coloured by ``get_node_color`` and labelled with its ``"label"``
    attribute, with its ``"instruction"`` attribute written to its right.

    Args:
        graph: A NetworkX graph whose nodes carry the ``"visits"`` and
            ``"avg_value"`` attributes, plus optional ``"label"`` and
            ``"instruction"``.
        save_path: If non-empty, the figure is also written to this path;
            missing parent directories are created.
        root: Id of the node to use as the top of the layout.
    """
    # Imported here so that building a graph does not need the plotting stack.
    import matplotlib.pyplot as plt
    import networkx as nx

    # Spring layout supplies a fallback position for every node; the layered
    # positions then override it for everything reachable from the root.
    pos = nx.spring_layout(graph, k=0.9, iterations=50)
    levels = nx.single_source_shortest_path_length(graph, root)
    pos.update(_layered_positions(levels))

    plt.figure(figsize=(30, 20))

    nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, arrowsize=40, width=3)

    node_colors = [get_node_color(graph.nodes[node]) for node in graph.nodes]
    nx.draw_networkx_nodes(graph, pos, node_size=NODE_SIZE, node_color=node_colors)

    labels = nx.get_node_attributes(graph, "label")
    nx.draw_networkx_labels(graph, pos, labels, font_size=NODE_FONT_SIZE)

    instructions = nx.get_node_attributes(graph, "instruction")
    for node, (x, y) in pos.items():
        instruction = instructions.get(node)
        if not instruction:
            # Nothing to annotate for nodes without an instruction.
            continue
        wrapped_text = textwrap.fill(instruction, width=30)
        plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center")

    plt.title("MCTS Tree Visualization", fontsize=40)
    plt.axis("off")
    plt.tight_layout()
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path)
    plt.show()


def build_tree_recursive(graph, parent_id, node: "Node", start_task_id=2):
    """Add ``node`` and all of its descendants to ``graph``.

    ``node`` is stored under ``parent_id`` with the attributes ``label``,
    ``avg_value``, ``dev_score``, ``visits`` and ``instruction``. The i-th
    child is stored under ``f"{parent_id}-{i}"`` and linked by an edge from
    ``parent_id``.

    Args:
        graph: Object providing ``add_node(id, **attrs)`` and
            ``add_edge(u, v)``, e.g. a ``networkx.DiGraph``.
        parent_id: Graph id to give ``node``.
        node: Tree node exposing ``load_role()``, ``get_depth()``,
            ``avg_value()``, ``raw_reward``, ``visited``, ``id`` and
            ``children`` (presumably an ``expo.MCTS.Node`` -- confirm).
        start_task_id: Number of leading plan tasks attributed to the root.
            A node at depth ``d > 0`` shows task ``d + start_task_id - 1``.

    Returns:
        ``graph``, so the call can be used directly as an expression.
    """
    tasks = node.load_role().planner.plan.tasks
    depth = node.get_depth()
    if depth == 0:
        instruction = "\n\n".join(task.instruction for task in tasks[:start_task_id])
    else:
        instruction = tasks[depth + start_task_id - 1].instruction

    avg_value = node.avg_value()
    avg_score = avg_value * 100
    dev_score = node.raw_reward.get("dev_score", 0) * 100
    graph.add_node(
        parent_id,
        label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
        avg_value=avg_value,
        dev_score=dev_score,
        visits=node.visited,
        instruction=instruction,
    )

    for i, child in enumerate(node.children or ()):
        child_id = f"{parent_id}-{i}"
        graph.add_edge(parent_id, child_id)
        # Forward start_task_id so a non-default offset applies at every depth.
        build_tree_recursive(graph, child_id, child, start_task_id)
    return graph


def main():
    """Load a saved MCTS tree and render it (driver from expo/scripts/visualize_experiment.py)."""
    import networkx as nx

    from expo.MCTS import MCTS, create_initial_state, initialize_di_root_node
    from expo.run_experiment import get_args
    from expo.utils import DATA_CONFIG

    args = get_args()
    state = create_initial_state(args.task, 0, DATA_CONFIG, args=args)
    _, node = initialize_di_root_node(state)
    mcts = MCTS(
        root_node=node,
        max_depth=5,
        use_fixed_insights=False,
    )
    mcts.load_tree()

    graph = nx.DiGraph()
    build_tree_recursive(graph, "0", mcts.root_node)
    visualize_tree(graph, save_path="../results/tree.png")


if __name__ == "__main__":
    main()