mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
add tree visualization script and function
This commit is contained in:
parent
a46f575361
commit
0f01c07b83
2 changed files with 128 additions and 3 deletions
|
|
@ -1,5 +1,8 @@
|
|||
import textwrap
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
from expo.MCTS import Node
|
||||
|
||||
NODE_TEMPLATE = """\
|
||||
|
|
@ -11,6 +14,9 @@ Score: {score}, Visits: {num_visits}
|
|||
|
||||
"""
|
||||
|
||||
NODE_SIZE = 12000
|
||||
NODE_FONT_SIZE = 18
|
||||
|
||||
|
||||
def get_role_plans(role):
|
||||
plans = role.planner.plan.tasks
|
||||
|
|
@ -42,7 +48,7 @@ def get_tree_text(node: Node):
|
|||
id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits
|
||||
)
|
||||
|
||||
def visualize_tree(node, depth=0, previous_plans=None):
|
||||
def visualize_tree_text(node, depth=0, previous_plans=None):
|
||||
text = ""
|
||||
if node is not None:
|
||||
text += visualize_node(node, previous_plans)
|
||||
|
|
@ -50,10 +56,106 @@ def get_tree_text(node: Node):
|
|||
code_set.update({task.instruction for task in role.planner.plan.tasks})
|
||||
previous_plans = get_role_plans(role)
|
||||
for child in node.children:
|
||||
text += textwrap.indent(visualize_tree(child, depth + 1, previous_plans), "\t")
|
||||
text += textwrap.indent(visualize_tree_text(child, depth + 1, previous_plans), "\t")
|
||||
return text
|
||||
|
||||
num_simulations = node.visited
|
||||
text = f"Number of simulations: {num_simulations}\n"
|
||||
text += visualize_tree(node)
|
||||
text += visualize_tree_text(node)
|
||||
return text, len(code_set)
|
||||
|
||||
|
||||
def get_node_color(node):
|
||||
if node["visits"] == 0:
|
||||
return "#D3D3D3"
|
||||
else:
|
||||
# The higher the avg_value, the more intense the color
|
||||
# avg_value is between 0 and 1
|
||||
avg_value = node["avg_value"]
|
||||
# Convert avg_value to a color ranging from red (low) to green (high)
|
||||
red = int(255 * (1 - avg_value))
|
||||
green = int(255 * avg_value)
|
||||
return f"#{red:02X}{green:02X}00"
|
||||
|
||||
|
||||
def visualize_tree(graph, save_path=""):
|
||||
# Use a hierarchical layout for tree-like visualization
|
||||
pos = nx.spring_layout(graph, k=0.9, iterations=50)
|
||||
|
||||
plt.figure(figsize=(30, 20)) # Further increase figure size for better visibility
|
||||
|
||||
# Calculate node levels
|
||||
root = "0"
|
||||
levels = nx.single_source_shortest_path_length(graph, root)
|
||||
max_level = max(levels.values())
|
||||
|
||||
# Adjust y-coordinates based on levels and x-coordinates to prevent overlap
|
||||
nodes_by_level = {}
|
||||
for node, level in levels.items():
|
||||
if level not in nodes_by_level:
|
||||
nodes_by_level[level] = []
|
||||
nodes_by_level[level].append(node)
|
||||
|
||||
for level, nodes in nodes_by_level.items():
|
||||
y = 1 - level / max_level
|
||||
x_step = 1.0 / (len(nodes) + 1)
|
||||
for i, node in enumerate(sorted(nodes)):
|
||||
pos[node] = ((i + 1) * x_step, y)
|
||||
|
||||
# Draw edges
|
||||
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, arrowsize=40, width=3)
|
||||
|
||||
# Draw nodes
|
||||
node_colors = [get_node_color(graph.nodes[node]) for node in graph.nodes]
|
||||
nx.draw_networkx_nodes(graph, pos, node_size=NODE_SIZE, node_color=node_colors)
|
||||
|
||||
# Add labels to nodes
|
||||
labels = nx.get_node_attributes(graph, "label")
|
||||
nx.draw_networkx_labels(graph, pos, labels, font_size=NODE_FONT_SIZE)
|
||||
|
||||
# Add instructions to the right side of nodes
|
||||
instructions = nx.get_node_attributes(graph, "instruction")
|
||||
for node, (x, y) in pos.items():
|
||||
wrapped_text = textwrap.fill(instructions[node], width=30) # Adjust width as needed
|
||||
plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center")
|
||||
|
||||
plt.title("MCTS Tree Visualization", fontsize=40)
|
||||
plt.axis("off") # Turn off axis
|
||||
plt.tight_layout()
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.show()
|
||||
|
||||
|
||||
def build_tree_recursive(graph, parent_id, node, start_task_id=2):
|
||||
"""
|
||||
Recursively builds the entire tree starting from the root node.
|
||||
Adds nodes and edges to the NetworkX graph.
|
||||
"""
|
||||
role = node.load_role()
|
||||
depth = node.get_depth()
|
||||
if depth == 0:
|
||||
instruction = "\n\n".join([role.planner.plan.tasks[i].instruction for i in range(start_task_id)])
|
||||
else:
|
||||
instruction = role.planner.plan.tasks[depth + start_task_id - 1].instruction
|
||||
print(instruction)
|
||||
# Add the current node with attributes to the graph
|
||||
dev_score = node.raw_reward.get("dev_score", 0) * 100
|
||||
avg_score = node.avg_value() * 100
|
||||
graph.add_node(
|
||||
parent_id,
|
||||
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
|
||||
avg_value=node.avg_value(),
|
||||
dev_score=dev_score,
|
||||
visits=node.visited,
|
||||
instruction=instruction,
|
||||
)
|
||||
# Stopping condition: if the node has no children, return
|
||||
if not node.children:
|
||||
return
|
||||
|
||||
# Recursively create all child nodes
|
||||
for i, child in enumerate(node.children):
|
||||
child_id = f"{parent_id}-{i}"
|
||||
graph.add_edge(parent_id, child_id)
|
||||
build_tree_recursive(graph, child_id, child)
|
||||
|
|
|
|||
23
expo/scripts/visualize_experiment.py
Normal file
23
expo/scripts/visualize_experiment.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import networkx as nx
|
||||
|
||||
from expo.evaluation.visualize_mcts import build_tree_recursive, visualize_tree
|
||||
from expo.MCTS import MCTS, create_initial_state, initialize_di_root_node
|
||||
from expo.run_experiment import get_args
|
||||
from expo.utils import DATA_CONFIG
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
data_config = DATA_CONFIG
|
||||
state = create_initial_state(args.task, 0, data_config, args=args)
|
||||
role, node = initialize_di_root_node(state)
|
||||
mcts = MCTS(
|
||||
root_node=node,
|
||||
max_depth=5,
|
||||
use_fixed_insights=False,
|
||||
)
|
||||
|
||||
mcts.load_tree()
|
||||
root = mcts.root_node
|
||||
G = nx.DiGraph()
|
||||
tree = build_tree_recursive(G, "0", root)
|
||||
visualize_tree(tree, save_path="../results/tree.png")
|
||||
Loading…
Add table
Add a link
Reference in a new issue