mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 05:42:37 +02:00
add random tree search
This commit is contained in:
parent
d34a482faf
commit
b776c7309b
4 changed files with 23 additions and 7 deletions
|
|
@ -1,3 +1,5 @@
|
|||
import random
|
||||
|
||||
from expo.MCTS import MCTS
|
||||
|
||||
|
||||
|
|
@ -7,3 +9,11 @@ class Greedy(MCTS):
|
|||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))
|
||||
|
||||
|
||||
class Random(MCTS):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return random.choice(all_children)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ DI_INSTRUCTION = """\
|
|||
2. Test set does not have the target column.
|
||||
3. You should perform transformations on train, dev, and test sets at the same time (it's a good idea to define functions for this and avoid code repetition).
|
||||
4. If labels are transformed during training, they should be transformed back to the original format before saving the predictions.
|
||||
5. You could split the training set further to make cross-validation and hyperparameter tuning.
|
||||
5. You could utilize dev set to improve the model.
|
||||
|
||||
## Saving Dev and Test Predictions
|
||||
1. Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory.
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from expo.evaluation.visualize_mcts import get_tree_text
|
||||
from expo.experimenter.experimenter import Experimenter
|
||||
from expo.Greedy import Greedy
|
||||
from expo.Greedy import Greedy, Random
|
||||
from expo.MCTS import MCTS
|
||||
|
||||
|
||||
class MCTSExperimenter(Experimenter):
|
||||
result_path: str = "results/mcts"
|
||||
|
||||
def __init__(self, args, greedy=False, **kwargs):
|
||||
def __init__(self, args, tree_mode=None, **kwargs):
|
||||
super().__init__(args, **kwargs)
|
||||
self.greedy = greedy
|
||||
self.tree_mode = tree_mode
|
||||
|
||||
async def run_experiment(self):
|
||||
if self.greedy:
|
||||
if self.tree_mode == "greedy":
|
||||
mcts = Greedy(root_node=None, max_depth=5)
|
||||
elif self.tree_mode == "random":
|
||||
mcts = Random(root_node=None, max_depth=5)
|
||||
else:
|
||||
mcts = MCTS(root_node=None, max_depth=5)
|
||||
best_nodes = await mcts.search(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ from expo.experimenter.mcts import MCTSExperimenter
|
|||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name", type=str, default="")
|
||||
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy"])
|
||||
parser.add_argument(
|
||||
"--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "random"]
|
||||
)
|
||||
get_di_args(parser)
|
||||
get_mcts_args(parser)
|
||||
get_aug_exp_args(parser)
|
||||
|
|
@ -42,7 +44,9 @@ async def main(args):
|
|||
if args.exp_mode == "mcts":
|
||||
experimenter = MCTSExperimenter(args)
|
||||
elif args.exp_mode == "greedy":
|
||||
experimenter = MCTSExperimenter(args, greedy=True)
|
||||
experimenter = MCTSExperimenter(args, tree_mode="greedy")
|
||||
elif args.exp_mode == "random":
|
||||
experimenter = MCTSExperimenter(args, tree_mode="random")
|
||||
elif args.exp_mode == "aug":
|
||||
experimenter = AugExperimenter(args)
|
||||
elif args.exp_mode == "base":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue