fix bug - make rollout more consistent

This commit is contained in:
Yizhou Chi 2024-08-30 20:35:17 +08:00
parent 32759f031c
commit ae6a195750
7 changed files with 50 additions and 10 deletions

View file

@ -284,7 +284,7 @@ class MCTS():
return self.root_node.visited
async def search(self, task, data_config, name,
rollout=3, load_tree=False, low_is_better=False, reflection=False):
rollouts, load_tree=False, low_is_better=False, reflection=False):
role, root = initialize_di_root_node(task, data_config, low_is_better=low_is_better, reflection=reflection, name=name)
self.root_node = root
@ -292,7 +292,12 @@ class MCTS():
if load_tree:
tree_loaded = self.load_tree()
mcts_logger.log("MCTS", f"Number of simulations: {self.get_num_simulations()}")
if not tree_loaded:
rollouts -= 2
if rollouts < 0:
raise ValueError("Rollouts must be greater than 2 if there is no tree to load")
self.children[root] = []
reward = await self.simulate(root, role)
self.backpropagate(root, reward)
@ -307,7 +312,7 @@ class MCTS():
else:
root = self.root_node
# 后续迭代使用UCT进行选择expand并模拟和反向传播
for _ in range(rollout): # 迭代次数
for _ in range(rollouts): # 迭代次数
mcts_logger.log("MCTS", f"开始第{_+1}次迭代")
leaf = self.select(root)
if leaf.is_terminal():

32
expo/README.md Normal file
View file

@ -0,0 +1,32 @@
# Expo
## Instruction
- 下载数据集: https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
## Examples
### Run Base DI
`python run_experiment.py --exp_mode base --task titanic`
### Run DI RandExp
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
### Run DI MCTS
`python run_experiment.py --exp_mode mcts --task titanic --rollouts 5`
`python run_experiment.py --exp_mode mcts --task house_prices --rollouts 5 --low_is_better`

View file

@ -17,7 +17,10 @@ Report {metric} on the eval data. Do not plot or make any visualizations.
TASK_PROMPT = """\
# User requirement
{user_requirement}
**Attention** Please do not leak the target label in any form during training.
**Attention**
1. Please do not leak the target label in any form during training.
2. Dev and Test sets do not have the target column.
3. You should perform transformations on all sets at the same step.
## Saving Dev and Test Predictions
Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory.

View file

@ -12,7 +12,7 @@ class MCTSExperimenter(Experimenter):
low_is_better=self.args.low_is_better,
load_tree=self.args.load_tree,
reflection=self.args.reflection,
rollout=self.args.rollout,
rollouts=self.args.rollouts,
name=self.args.name)
best_node = best_nodes["global_best"]
dev_best_node = best_nodes["dev_best"]

View file

@ -16,12 +16,12 @@ def get_args():
def get_mcts_args(parser):
parser.add_argument("--load_tree", dest="load_tree", action="store_true")
parser.add_argument("--no_load_tree", dest="load_tree", action="store_false")
parser.set_defaults(load_tree=True)
parser.add_argument("--rollout", type=int, default=3)
parser.set_defaults(load_tree=False)
parser.add_argument("--rollouts", type=int, default=3)
def get_aug_exp_args(parser):
parser.add_argument("--aug_mode", type=str, default="single", choices=["single", "set"])
parser.add_argument("--num_experiments", type=int, default=2)
parser.add_argument("--num_experiments", type=int, default=1)
def get_di_args(parser):

View file

@ -18,7 +18,7 @@ def get_args():
parser.add_argument("--reflection", dest="reflection", action="store_true")
parser.add_argument("--no_reflection", dest="reflection", action="store_false")
parser.set_defaults(reflection=True)
parser.add_argument("--rollout", type=int, default=3)
parser.add_argument("--rollouts", type=int, default=3)
parser.add_argument("--name", type=str, default="")
return parser.parse_args()
@ -37,7 +37,7 @@ if __name__ == "__main__":
mcts = MCTS(root_node=None, max_depth=5)
best_nodes = asyncio.run(mcts.search(args.task, data_config,
low_is_better=args.low_is_better, load_tree=args.load_tree,
reflection=args.reflection, rollout=args.rollout, name=args.name))
reflection=args.reflection, rollouts=args.rollouts, name=args.name))
best_node = best_nodes["global_best"]
dev_best_node = best_nodes["dev_best"]
text, num_generated_codes = get_tree_text(mcts.root_node)

View file

@ -25,7 +25,7 @@ The current task is about feature engineering. when performing it, please adhere
- Use available feature engineering tools if they are potential impactful.
- Avoid creating redundant or excessively numerous features in one step.
- Exclude ID columns from feature generation and remove them.
- Each feature engineering operation performed on the train set must also applies to the test separately at the same time.
- Each feature engineering operation performed on the train set must also applies to the dev/test separately at the same time.
- Avoid using the label column to create features, except for cat encoding.
- Use the data from previous task result if exist, do not mock or reload data yourself.
- Always copy the DataFrame before processing it and use the copy to process.