rename expo folder to sela

This commit is contained in:
Cyzus Chi 2024-10-22 21:33:31 +08:00
parent 4bed19b931
commit 7c5b29de63
33 changed files with 53 additions and 53 deletions

15
sela/scripts/run_cls.sh Normal file
View file

@ -0,0 +1,15 @@
#!/bin/bash
tasks=("smoker-status" "software-defects" "jasmine" "credit-g" "Click_prediction_small" "kick" "kc1" "titanic" "icr" "wine-quality-white" "mfeat-factors" "segment" "GesturePhaseSegmentationProcessed")
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --special_instruction stacking
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

View file

@ -0,0 +1,13 @@
#!/bin/bash
tasks=("banking77" "gnad10" "sms_spam" "oxford-iiit-pet" "stanford_cars" "fashion_mnist" )
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

14
sela/scripts/run_reg.sh Normal file
View file

@ -0,0 +1,14 @@
#!/bin/bash
tasks=("concrete-strength" "Moneyball" "colleges" "SAT11-HAND-runtime-regression" "diamonds" "boston" "house-prices")
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --low_is_better --special_instruction stacking
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

View file

@ -0,0 +1,25 @@
import networkx as nx
from sela.evaluation.visualize_mcts import build_tree_recursive, visualize_tree
from sela.MCTS import MCTS, create_initial_state, initialize_di_root_node
from sela.run_experiment import get_args
from sela.utils import DATA_CONFIG
if __name__ == "__main__":
args = get_args()
data_config = DATA_CONFIG
state = create_initial_state(args.task, 0, data_config, args=args)
role, node = initialize_di_root_node(state)
mcts = MCTS(
root_node=node,
max_depth=5,
use_fixed_insights=False,
)
mcts.load_tree()
mcts.load_node_order()
root = mcts.root_node
node_order = mcts.node_order
G = nx.DiGraph()
build_tree_recursive(G, "0", root, node_order)
visualize_tree(G, save_path=f"results/{args.task}-tree.png")