mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-11 08:42:38 +02:00
fix bug - make rollout more consistent
This commit is contained in:
parent
32759f031c
commit
ae6a195750
7 changed files with 50 additions and 10 deletions
|
|
@ -284,7 +284,7 @@ class MCTS():
|
|||
return self.root_node.visited
|
||||
|
||||
async def search(self, task, data_config, name,
|
||||
rollout=3, load_tree=False, low_is_better=False, reflection=False):
|
||||
rollouts, load_tree=False, low_is_better=False, reflection=False):
|
||||
|
||||
role, root = initialize_di_root_node(task, data_config, low_is_better=low_is_better, reflection=reflection, name=name)
|
||||
self.root_node = root
|
||||
|
|
@ -292,7 +292,12 @@ class MCTS():
|
|||
if load_tree:
|
||||
tree_loaded = self.load_tree()
|
||||
mcts_logger.log("MCTS", f"Number of simulations: {self.get_num_simulations()}")
|
||||
|
||||
|
||||
if not tree_loaded:
|
||||
rollouts -= 2
|
||||
if rollouts < 0:
|
||||
raise ValueError("Rollouts must be greater than 2 if there is no tree to load")
|
||||
self.children[root] = []
|
||||
reward = await self.simulate(root, role)
|
||||
self.backpropagate(root, reward)
|
||||
|
|
@ -307,7 +312,7 @@ class MCTS():
|
|||
else:
|
||||
root = self.root_node
|
||||
# 后续迭代:使用UCT进行选择,expand并模拟和反向传播
|
||||
for _ in range(rollout): # 迭代次数
|
||||
for _ in range(rollouts): # 迭代次数
|
||||
mcts_logger.log("MCTS", f"开始第{_+1}次迭代")
|
||||
leaf = self.select(root)
|
||||
if leaf.is_terminal():
|
||||
|
|
|
|||
32
expo/README.md
Normal file
32
expo/README.md
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Expo
|
||||
|
||||
|
||||
## Instruction
|
||||
|
||||
- 下载数据集:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
|
||||
|
||||
|
||||
## Examples
|
||||
|
||||
### Run Base DI
|
||||
|
||||
`python run_experiment.py --exp_mode base --task titanic`
|
||||
|
||||
### Run DI RandExp
|
||||
|
||||
- Single insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
|
||||
|
||||
- Set insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
|
||||
|
||||
|
||||
|
||||
### Run DI MCTS
|
||||
`python run_experiment.py --exp_mode mcts --task titanic --rollout 5`
|
||||
|
||||
`python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 5 --low_is_better`
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -17,7 +17,10 @@ Report {metric} on the eval data. Do not plot or make any visualizations.
|
|||
TASK_PROMPT = """\
|
||||
# User requirement
|
||||
{user_requirement}
|
||||
**Attention** Please do not leak the target label in any form during training.
|
||||
**Attention**
|
||||
1. Please do not leak the target label in any form during training.
|
||||
2. Dev and Test sets do not have the target column.
|
||||
3. You should perform transformations on all sets at the same step.
|
||||
|
||||
## Saving Dev and Test Predictions
|
||||
Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory.
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class MCTSExperimenter(Experimenter):
|
|||
low_is_better=self.args.low_is_better,
|
||||
load_tree=self.args.load_tree,
|
||||
reflection=self.args.reflection,
|
||||
rollout=self.args.rollout,
|
||||
rollouts=self.args.rollouts,
|
||||
name=self.args.name)
|
||||
best_node = best_nodes["global_best"]
|
||||
dev_best_node = best_nodes["dev_best"]
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ def get_args():
|
|||
def get_mcts_args(parser):
|
||||
parser.add_argument("--load_tree", dest="load_tree", action="store_true")
|
||||
parser.add_argument("--no_load_tree", dest="load_tree", action="store_false")
|
||||
parser.set_defaults(load_tree=True)
|
||||
parser.add_argument("--rollout", type=int, default=3)
|
||||
parser.set_defaults(load_tree=False)
|
||||
parser.add_argument("--rollouts", type=int, default=3)
|
||||
|
||||
def get_aug_exp_args(parser):
|
||||
parser.add_argument("--aug_mode", type=str, default="single", choices=["single", "set"])
|
||||
parser.add_argument("--num_experiments", type=int, default=2)
|
||||
parser.add_argument("--num_experiments", type=int, default=1)
|
||||
|
||||
|
||||
def get_di_args(parser):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def get_args():
|
|||
parser.add_argument("--reflection", dest="reflection", action="store_true")
|
||||
parser.add_argument("--no_reflection", dest="reflection", action="store_false")
|
||||
parser.set_defaults(reflection=True)
|
||||
parser.add_argument("--rollout", type=int, default=3)
|
||||
parser.add_argument("--rollouts", type=int, default=3)
|
||||
parser.add_argument("--name", type=str, default="")
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
|||
mcts = MCTS(root_node=None, max_depth=5)
|
||||
best_nodes = asyncio.run(mcts.search(args.task, data_config,
|
||||
low_is_better=args.low_is_better, load_tree=args.load_tree,
|
||||
reflection=args.reflection, rollout=args.rollout, name=args.name))
|
||||
reflection=args.reflection, rollouts=args.rollouts, name=args.name))
|
||||
best_node = best_nodes["global_best"]
|
||||
dev_best_node = best_nodes["dev_best"]
|
||||
text, num_generated_codes = get_tree_text(mcts.root_node)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ The current task is about feature engineering. when performing it, please adhere
|
|||
- Use available feature engineering tools if they are potential impactful.
|
||||
- Avoid creating redundant or excessively numerous features in one step.
|
||||
- Exclude ID columns from feature generation and remove them.
|
||||
- Each feature engineering operation performed on the train set must also applies to the test separately at the same time.
|
||||
- Each feature engineering operation performed on the train set must also applies to the dev/test separately at the same time.
|
||||
- Avoid using the label column to create features, except for cat encoding.
|
||||
- Use the data from previous task result if exist, do not mock or reload data yourself.
|
||||
- Always copy the DataFrame before processing it and use the copy to process.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue