diff --git a/expo/README.md b/expo/README.md
index 011322897..3f9e630e5 100644
--- a/expo/README.md
+++ b/expo/README.md
@@ -215,6 +215,8 @@ #### Setup
 pip install -U pip
 pip install -U setuptools wheel
 pip install autogluon
+
+python run_experiment.py --exp_mode autogluon --task fashion_mnist
 ```
 提供github链接,并说明使用的命令以及参数设置
 
diff --git a/expo/experimenter/autogluon.py b/expo/experimenter/autogluon.py
index 93dfdb4bc..4bcba432c 100644
--- a/expo/experimenter/autogluon.py
+++ b/expo/experimenter/autogluon.py
@@ -32,6 +32,74 @@ class AGRunner:
         test_preds = predictor.predict(test_data)
         return {"test_preds": test_preds, "dev_preds": dev_preds}
 
+    def run_images(self):
+        """Train an AutoGluon MultiModalPredictor on an image dataset.
+
+        Returns:
+            dict: {"dev_preds": ..., "test_preds": ...} — predictions for the
+            dev-without-target and test-without-target splits.
+        """
+        from autogluon.multimodal import MultiModalPredictor
+
+        target_col = self.state["dataset_config"]["target_col"]
+        train_path = self.datasets["train"]
+        dev_path = self.datasets["dev"]
+        dev_wo_target_path = self.datasets["dev_wo_target"]
+        test_wo_target_path = self.datasets["test_wo_target"]
+        # AutoGluon expects underscore-separated metric names (e.g. "log_loss").
+        eval_metric = self.state["dataset_config"]["metric"].replace(" ", "_")
+
+        train_data, dev_data, dev_wo_target_data, test_data = self.load_split_dataset(
+            train_path, dev_path, dev_wo_target_path, test_wo_target_path
+        )
+
+        predictor = MultiModalPredictor(
+            label=target_col,
+            eval_metric=eval_metric,
+            path="AutogluonModels/ag-{}-{}".format(self.state["task"], datetime.now().strftime("%y%m%d_%H%M")),
+        ).fit(train_data=train_data, tuning_data=dev_data, time_limit=self.time_limit)
+
+        dev_preds = predictor.predict(dev_wo_target_data)
+        test_preds = predictor.predict(test_data)
+        return {"dev_preds": dev_preds, "test_preds": test_preds}
+
+    def load_split_dataset(self, train_path, dev_path, dev_wo_target_path, test_wo_target_path):
+        """Load the CSV splits and absolutize the image-path column.
+
+        Args:
+            train_path (str): Path to the training split CSV.
+            dev_path (str): Path to the dev split CSV (with target labels).
+            dev_wo_target_path (str): Path to the dev split CSV without target labels.
+            test_wo_target_path (str): Path to the test split CSV without target labels.
+
+        Returns:
+            tuple: (train, dev, dev_without_target, test) DataFrames whose first
+            column (assumed to hold relative image paths — confirm against the
+            dataset layout) is prefixed with the dataset root directory.
+        """
+        import os
+
+        import pandas as pd
+
+        # The dataset root is machine-specific: allow overriding via environment
+        # variable instead of hard-coding one developer's local path.
+        dataset_root = os.environ.get("SELA_DATASET_ROOT", "F:/Download/Dataset/")
+        root_folder = os.path.join(dataset_root, self.state["task"])
+
+        def _load(path):
+            # Read one split and prefix its (assumed) image-path first column.
+            frame = pd.read_csv(path)
+            image_column = frame.columns[0]
+            frame[image_column] = frame[image_column].apply(lambda p: os.path.join(root_folder, p))
+            return frame
+
+        return (
+            _load(train_path),
+            _load(dev_path),
+            _load(dev_wo_target_path),
+            _load(test_wo_target_path),
+        )
+
 class GluonExperimenter(CustomExperimenter):
     result_path: str = "results/autogluon"
 
@@ -41,7 +109,7 @@ class GluonExperimenter(CustomExperimenter):
         self.framework = AGRunner(self.state)
 
     async def run_experiment(self):
-        result = self.framework.run()
+        result = self.framework.run_images()  # TODO: dispatch run()/run_images() by task modality
         user_requirement = self.state["requirement"]
         dev_preds = result["dev_preds"]
         test_preds = result["test_preds"]