add tree visualization script and function

This commit is contained in:
Yizhou Chi 2024-10-17 10:27:33 +08:00
parent a46f575361
commit 0f01c07b83
2 changed files with 128 additions and 3 deletions

View file

@ -0,0 +1,23 @@
import networkx as nx
from expo.evaluation.visualize_mcts import build_tree_recursive, visualize_tree
from expo.MCTS import MCTS, create_initial_state, initialize_di_root_node
from expo.run_experiment import get_args
from expo.utils import DATA_CONFIG
if __name__ == "__main__":
args = get_args()
data_config = DATA_CONFIG
state = create_initial_state(args.task, 0, data_config, args=args)
role, node = initialize_di_root_node(state)
mcts = MCTS(
root_node=node,
max_depth=5,
use_fixed_insights=False,
)
mcts.load_tree()
root = mcts.root_node
G = nx.DiGraph()
tree = build_tree_recursive(G, "0", root)
visualize_tree(tree, save_path="../results/tree.png")