rename aug to rs

This commit is contained in:
Cyzus Chi 2024-10-25 23:07:14 +08:00
parent a62ae88187
commit 76029782cc
3 changed files with 15 additions and 15 deletions

View file

@ -77,10 +77,10 @@ #### Ablation Study
**DI RandomSearch**
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
`python run_experiment.py --exp_mode rs --task titanic --rs_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
`python run_experiment.py --exp_mode rs --task titanic --rs_mode set`
## 4. Evaluation

View file

@ -10,8 +10,8 @@ When doing the tasks, you can refer to the insights below:
"""
class AugExperimenter(Experimenter):
result_path: str = "results/aug"
class RandomSearchExperimenter(Experimenter):
result_path: str = "results/random_search"
async def run_experiment(self):
# state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
@ -20,17 +20,17 @@ class AugExperimenter(Experimenter):
exp_pool = InstructionGenerator.load_analysis_pool(
exp_pool_path, use_fixed_insights=self.args.use_fixed_insights
)
if self.args.aug_mode == "single":
if self.args.rs_mode == "single":
exps = InstructionGenerator._random_sample(exp_pool, self.args.num_experiments)
exps = [exp["Analysis"] for exp in exps]
elif self.args.aug_mode == "set":
elif self.args.rs_mode == "set":
exps = []
for i in range(self.args.num_experiments):
exp_set = InstructionGenerator.sample_instruction_set(exp_pool)
exp_set_text = "\n".join([f"{exp['task_id']}: {exp['Analysis']}" for exp in exp_set])
exps.append(exp_set_text)
else:
raise ValueError(f"Invalid mode: {self.args.aug_mode}")
raise ValueError(f"Invalid mode: {self.args.rs_mode}")
results = []
for i in range(self.args.num_experiments):
@ -45,7 +45,7 @@ class AugExperimenter(Experimenter):
{
"idx": i,
"score_dict": score_dict,
"aug_mode": self.args.aug_mode,
"rs_mode": self.args.rs_mode,
"insights": exps[i],
"user_requirement": requirement,
"args": vars(self.args),

View file

@ -2,7 +2,7 @@ import argparse
import asyncio
from metagpt.ext.sela.data.custom_task import get_mle_is_lower_better, get_mle_task_id
from metagpt.ext.sela.experimenter.aug import AugExperimenter
from metagpt.ext.sela.experimenter.random_search import RandomSearchExperimenter
from metagpt.ext.sela.experimenter.autogluon import GluonExperimenter
from metagpt.ext.sela.experimenter.autosklearn import AutoSklearnExperimenter
from metagpt.ext.sela.experimenter.custom import CustomExperimenter
@ -17,12 +17,12 @@ def get_args(cmd=True):
"--exp_mode",
type=str,
default="mcts",
choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"],
choices=["mcts", "rs", "base", "custom", "greedy", "autogluon", "random", "autosklearn"],
)
parser.add_argument("--role_timeout", type=int, default=1000)
get_di_args(parser)
get_mcts_args(parser)
get_aug_exp_args(parser)
get_rs_exp_args(parser)
if cmd:
args = parser.parse_args()
else:
@ -56,8 +56,8 @@ def get_mcts_args(parser):
parser.add_argument("--max_depth", type=int, default=4)
def get_aug_exp_args(parser):
parser.add_argument("--aug_mode", type=str, default="single", choices=["single", "set"])
def get_rs_exp_args(parser):
parser.add_argument("--rs_mode", type=str, default="single", choices=["single", "set"])
parser.add_argument("--is_multimodal", action="store_true", help="Specify if the model is multi-modal")
@ -79,8 +79,8 @@ async def main(args):
experimenter = MCTSExperimenter(args, tree_mode="greedy")
elif args.exp_mode == "random":
experimenter = MCTSExperimenter(args, tree_mode="random")
elif args.exp_mode == "aug":
experimenter = AugExperimenter(args)
elif args.exp_mode == "rs":
experimenter = RandomSearchExperimenter(args)
elif args.exp_mode == "base":
experimenter = Experimenter(args)
elif args.exp_mode == "autogluon":