diff --git a/expo/data/dataset.py b/expo/data/dataset.py index 28bd26d2e..8ad6f2854 100644 --- a/expo/data/dataset.py +++ b/expo/data/dataset.py @@ -20,12 +20,23 @@ USE_AG = """ 7. Please use autogluon for model training with presets='medium_quality', time_limit=None, give dev dataset to tuning_data, and use right eval_metric. """ +TEXT_MODALITY = """ +7. You could use models from transformers library for this text dataset. +8. Use gpu if available for faster training. +""" + +IMAGE_MODALITY = """ +7. You could use models from torchvision library for this image dataset. +8. Use gpu if available for faster training. +""" + STACKING = """ 7. To avoid overfitting, train a weighted ensemble model such as StackingClassifier or StackingRegressor. 8. You could do some quick model prototyping to see which models work best and then use them in the ensemble. """ -SPECIAL_INSTRUCTIONS = {"ag": USE_AG, "stacking": STACKING} + +SPECIAL_INSTRUCTIONS = {"ag": USE_AG, "stacking": STACKING, "text": TEXT_MODALITY, "image": IMAGE_MODALITY} DI_INSTRUCTION = """ ## Attention diff --git a/expo/run_experiment.py b/expo/run_experiment.py index be028c47e..8dd66577c 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -3,10 +3,10 @@ import asyncio from expo.experimenter.aug import AugExperimenter from expo.experimenter.autogluon import GluonExperimenter +from expo.experimenter.autosklearn import AutoSklearnExperimenter from expo.experimenter.custom import CustomExperimenter from expo.experimenter.experimenter import Experimenter from expo.experimenter.mcts import MCTSExperimenter -from expo.experimenter.autosklearn import AutoSklearnExperimenter def get_args(): @@ -43,7 +43,7 @@ def get_di_args(parser): parser.add_argument("--reflection", dest="reflection", action="store_true") parser.add_argument("--no_reflection", dest="reflection", action="store_false") parser.add_argument("--num_experiments", type=int, default=1) - parser.add_argument("--special_instruction", type=str, default=None, choices=["ag", "stacking"]) + parser.add_argument("--special_instruction", type=str, default=None, choices=["ag", "stacking", "text", "image"]) parser.set_defaults(reflection=True)