diff --git a/expo/experimenter/aug.py b/expo/experimenter/aug.py index e57d024bd..ffe0d04c5 100644 --- a/expo/experimenter/aug.py +++ b/expo/experimenter/aug.py @@ -24,9 +24,11 @@ class AugExperimenter(Experimenter): exps = InstructionGenerator._random_sample(exp_pool, self.args.num_experiments) exps = [exp["Analysis"] for exp in exps] elif self.args.aug_mode == "set": - exp_set = InstructionGenerator.sample_instruction_set(exp_pool) - exp_set_text = "\n".join([f"{exp['task_id']}: {exp['Analysis']}" for exp in exp_set]) - exps = [exp_set_text] * self.args.num_experiments + exps = [] + for i in range(self.args.num_experiments): + exp_set = InstructionGenerator.sample_instruction_set(exp_pool) + exp_set_text = "\n".join([f"{exp['task_id']}: {exp['Analysis']}" for exp in exp_set]) + exps.append(exp_set_text) else: raise ValueError(f"Invalid mode: {self.args.aug_mode}")